Skip to content

Commit

Permalink
Merge pull request #148 from asappresearch/pre-post
Browse files Browse the repository at this point in the history
Pre post
  • Loading branch information
taoleicn authored Dec 17, 2020
2 parents 1569f31 + ec1566e commit 2ef3e93
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 18 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,10 @@ __pycache__
*.egg-info/
sru/csrc/build
.python-version

# created by torchscript tests:
py_out.txt
sru/csrc/Makefile
sru/csrc/tests/main_test_cpp.cpp
sru_ts.pt
cpp_out.txt
26 changes: 20 additions & 6 deletions sru/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def __init__(self,
rescale: bool = True,
v1: bool = False,
custom_m: Optional[nn.Module] = None,
amp_recurrence_fp16: bool = False):
amp_recurrence_fp16: bool = False,
normalize_after: bool = False):
"""Initialize the SRUCell module.
Parameters
Expand Down Expand Up @@ -94,6 +95,8 @@ def __init__(self,
When using AMP autocast, selects which type to use
for recurrence custom kernel.
False: torch.float32, True: torch.float16
normalize_after: bool
if True use post layer norm, else pre layer norm
"""
super(SRUCell, self).__init__()
self.input_size = input_size
Expand All @@ -113,6 +116,7 @@ def __init__(self,
self.activation_type = 1
self.activation = 'tanh'
self.amp_recurrence_fp16 = amp_recurrence_fp16
self.normalize_after = normalize_after

# projection dimension
self.projection_size = 0
Expand Down Expand Up @@ -146,7 +150,10 @@ def __init__(self,

self.layer_norm: Optional[nn.Module]= None
if layer_norm:
self.layer_norm = nn.LayerNorm(self.input_size)
if normalize_after:
self.layer_norm = nn.LayerNorm(self.output_size)
else:
self.layer_norm = nn.LayerNorm(self.input_size)

self.reset_parameters()

Expand Down Expand Up @@ -235,7 +242,7 @@ def forward(self,

# apply layer norm before activation (i.e. before SRU computation)
residual = input
if self.layer_norm is not None:
if self.layer_norm is not None and not self.normalize_after:
input = self.layer_norm(input)

# apply dropout for multiplication
Expand All @@ -260,6 +267,10 @@ def forward(self,

# apply elementwise recurrence to get hidden states h and c
h, c = self.apply_recurrence(U, V, residual, c0, scale_val, mask_c, mask_pad)

if self.layer_norm is not None and self.normalize_after:
h = self.layer_norm(h)

return h, c

def apply_recurrence(self,
Expand Down Expand Up @@ -427,7 +438,8 @@ def __init__(self,
nn_rnn_compatible_return: bool = False,
custom_m: Optional[Union[nn.Module, List[nn.Module]]] = None,
proj_input_to_hidden_first: bool = False,
amp_recurrence_fp16: bool = False):
amp_recurrence_fp16: bool = False,
normalize_after: bool = False):
"""Initialize the SRU module.
Parameters
Expand Down Expand Up @@ -486,7 +498,8 @@ def __init__(self,
When using AMP autocast, selects which type to use
for recurrence custom kernel.
False: torch.float32, True: torch.float16
normalize_after: bool
if True use post layer norm, else use pre layer norm
"""

super(SRU, self).__init__()
Expand Down Expand Up @@ -538,7 +551,8 @@ def __init__(self,
rescale=rescale,
v1=v1,
custom_m=custom_m_i,
amp_recurrence_fp16=amp_recurrence_fp16
amp_recurrence_fp16=amp_recurrence_fp16,
normalize_after=normalize_after
)
rnn_lst.append(layer_i)
self.rnn_lst = rnn_lst
Expand Down
7 changes: 6 additions & 1 deletion test/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
set -e
set -x

python test/test_ts_cpp.py > py_out.txt
cd sru/csrc/
if [[ -d build ]]; then {
rm -Rf build
Expand All @@ -17,5 +16,11 @@ cd build
cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch; import os.path; print(os.path.join(os.path.dirname(torch.__file__), "share", "cmake"))')" ..
make -j
cd ../../../

python test/test_ts_cpp.py > py_out.txt
sru/csrc/build/example_app sru_ts.pt > cpp_out.txt
diff cpp_out.txt py_out.txt

python test/test_ts_cpp.py --normalize-after > py_out.txt
sru/csrc/build/example_app sru_ts.pt > cpp_out.txt
diff cpp_out.txt py_out.txt
32 changes: 21 additions & 11 deletions test/test_ts_cpp.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
import torch
import sru
import argparse

D = 4
model = sru.SRU(D, D, num_layers=2)
model.eval()

ts_model = torch.jit.script(model)
ts_model.save('sru_ts.pt')
def run(args):
D = 4
model = sru.SRU(D, D, num_layers=2, normalize_after=args.normalize_after)
model.eval()

with torch.no_grad():
x = torch.ones(3, 2, D)
h, c = model(x)
h, c = h.view(-1), c.view(-1)
print(''.join(["{:.4f} ".format(x.item()) for x in h]))
print(''.join(["{:.4f} ".format(x.item()) for x in c]))
ts_model = torch.jit.script(model)
ts_model.save('sru_ts.pt')

with torch.no_grad():
x = torch.ones(3, 2, D)
h, c = model(x)
h, c = h.view(-1), c.view(-1)
print(''.join(["{:.4f} ".format(x.item()) for x in h]))
print(''.join(["{:.4f} ".format(x.item()) for x in c]))


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--normalize-after', action='store_true')
args = parser.parse_args()
run(args)

0 comments on commit 2ef3e93

Please sign in to comment.