Merge remote-tracking branch 'origin/master' into smaller-weight-c
# Conflicts:
#	sru/modules.py
hpasapp committed Dec 17, 2020
2 parents 5688728 + 318ab42 commit f4e1687
Showing 5 changed files with 59 additions and 17 deletions.
10 changes: 8 additions & 2 deletions .circleci/config.yml
@@ -13,7 +13,10 @@ jobs:
command: |
virtualenv -p python3.7 .venv
source .venv/bin/activate
pip install -q torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# pip install -q torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# temporary workaround for https://github.com/pytorch/pytorch/issues/49560
wget https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
pip install -q torch-1.6.0+cpu-cp37-cp37m-linux_x86_64.whl
pip install -q .
pip install -q -r requirements-test.txt
- run:
@@ -32,7 +35,10 @@ jobs:
command: |
virtualenv -p python3.7 .venv
source .venv/bin/activate
pip install -q torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# pip install -q torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# temporary workaround for https://github.com/pytorch/pytorch/issues/49560
wget https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
pip install -q torch-1.6.0+cpu-cp37-cp37m-linux_x86_64.whl
pip install -q .
sudo apt-get -y install cmake
- run:
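
Both CI jobs apply the same workaround: instead of installing torch==1.6.0+cpu through pip's -f index (the step affected by the linked pytorch issue), the CPU wheel is fetched with wget and installed from the local file. A quick sanity check, not part of the commit, that the virtualenv ended up with the intended CPU-only build might look like this:

# Hedged sanity check (not part of the commit): confirm the wheel installed
# from the local file is the CPU-only torch 1.6.0 build the CI jobs expect.
import torch

print("torch version:", torch.__version__)           # expected to report 1.6.0+cpu
print("CUDA available:", torch.cuda.is_available())  # expected False for the +cpu wheel
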
7 changes: 7 additions & 0 deletions .gitignore
@@ -2,3 +2,10 @@ __pycache__
*.egg-info/
sru/csrc/build
.python-version

# created by torchscript tests:
py_out.txt
sru/csrc/Makefile
sru/csrc/tests/main_test_cpp.cpp
sru_ts.pt
cpp_out.txt
20 changes: 17 additions & 3 deletions sru/modules.py
@@ -42,6 +42,7 @@ def __init__(self,
v1: bool = False,
custom_m: Optional[nn.Module] = None,
amp_recurrence_fp16: bool = False,
normalize_after: bool = False,
weight_c_init: Optional[float] = None):
"""Initialize the SRUCell module.
@@ -95,6 +96,8 @@ def __init__(self,
When using AMP autocast, selects which type to use
for recurrence custom kernel.
False: torch.float32, True: torch.float16
normalize_after: bool
if True use post layer norm, else pre layer norm
weight_c_init: Optional[float]
if not None, then size of uniform initialization of weight_c
"""
@@ -116,6 +119,7 @@ def __init__(self,
self.activation_type = 1
self.activation = 'tanh'
self.amp_recurrence_fp16 = amp_recurrence_fp16
self.normalize_after = normalize_after
self.weight_c_init = weight_c_init

# projection dimension
@@ -150,7 +154,10 @@ def __init__(self,

self.layer_norm: Optional[nn.Module]= None
if layer_norm:
self.layer_norm = nn.LayerNorm(self.input_size)
if normalize_after:
self.layer_norm = nn.LayerNorm(self.output_size)
else:
self.layer_norm = nn.LayerNorm(self.input_size)

self.reset_parameters()

@@ -242,7 +249,7 @@ def forward(self,

# apply layer norm before activation (i.e. before SRU computation)
residual = input
if self.layer_norm is not None:
if self.layer_norm is not None and not self.normalize_after:
input = self.layer_norm(input)

# apply dropout for multiplication
@@ -267,6 +274,10 @@

# apply elementwise recurrence to get hidden states h and c
h, c = self.apply_recurrence(U, V, residual, c0, scale_val, mask_c, mask_pad)

if self.layer_norm is not None and self.normalize_after:
h = self.layer_norm(h)

return h, c

def apply_recurrence(self,
@@ -435,6 +446,7 @@ def __init__(self,
custom_m: Optional[Union[nn.Module, List[nn.Module]]] = None,
proj_input_to_hidden_first: bool = False,
amp_recurrence_fp16: bool = False,
normalize_after: bool = False,
weight_c_init: Optional[float] = None):
"""Initialize the SRU module.
@@ -494,9 +506,10 @@ def __init__(self,
When using AMP autocast, selects which type to use
for recurrence custom kernel.
False: torch.float32, True: torch.float16
normalize_after: bool
if True use post layer norm, else use pre layer norm
weight_c_init: Optional[float]
if not None, then size of uniform initialization of weight_c
"""

super(SRU, self).__init__()
@@ -549,6 +562,7 @@ def __init__(self,
v1=v1,
custom_m=custom_m_i,
amp_recurrence_fp16=amp_recurrence_fp16,
normalize_after=normalize_after,
weight_c_init=weight_c_init,
)
rnn_lst.append(layer_i)
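
Taken together, the modules.py changes thread a new normalize_after flag from SRU down to every SRUCell: with the default pre layer norm, a LayerNorm over input_size is applied to the cell input before the recurrence; with normalize_after=True, a LayerNorm over output_size is applied to the hidden states h after the recurrence. A minimal usage sketch, assuming the SRU constructor also exposes the layer_norm switch that creates the LayerNorm in the first place:

import torch
import sru

# Hypothetical usage (constructor arguments other than normalize_after are
# assumptions): a 2-layer SRU with post layer norm on each cell's output.
model = sru.SRU(128, 128, num_layers=2, layer_norm=True, normalize_after=True)
model.eval()

with torch.no_grad():
    x = torch.randn(16, 8, 128)   # (seq_len, batch, input_size)
    h, c = model(x)
    print(h.shape, c.shape)
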
7 changes: 6 additions & 1 deletion test/test.sh
@@ -7,7 +7,6 @@
set -e
set -x

python test/test_ts_cpp.py > py_out.txt
cd sru/csrc/
if [[ -d build ]]; then {
rm -Rf build
@@ -17,5 +16,11 @@ cd build
cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch; import os.path; print(os.path.join(os.path.dirname(torch.__file__), "share", "cmake"))')" ..
make -j
cd ../../../

python test/test_ts_cpp.py > py_out.txt
sru/csrc/build/example_app sru_ts.pt > cpp_out.txt
diff cpp_out.txt py_out.txt

python test/test_ts_cpp.py --normalize-after > py_out.txt
sru/csrc/build/example_app sru_ts.pt > cpp_out.txt
diff cpp_out.txt py_out.txt
32 changes: 21 additions & 11 deletions test/test_ts_cpp.py
@@ -1,16 +1,26 @@
import torch
import sru
import argparse

D = 4
model = sru.SRU(D, D, num_layers=2)
model.eval()

ts_model = torch.jit.script(model)
ts_model.save('sru_ts.pt')
def run(args):
D = 4
model = sru.SRU(D, D, num_layers=2, normalize_after=args.normalize_after)
model.eval()

with torch.no_grad():
x = torch.ones(3, 2, D)
h, c = model(x)
h, c = h.view(-1), c.view(-1)
print(''.join(["{:.4f} ".format(x.item()) for x in h]))
print(''.join(["{:.4f} ".format(x.item()) for x in c]))
ts_model = torch.jit.script(model)
ts_model.save('sru_ts.pt')

with torch.no_grad():
x = torch.ones(3, 2, D)
h, c = model(x)
h, c = h.view(-1), c.view(-1)
print(''.join(["{:.4f} ".format(x.item()) for x in h]))
print(''.join(["{:.4f} ".format(x.item()) for x in c]))


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--normalize-after', action='store_true')
args = parser.parse_args()
run(args)
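
The C++ half of the comparison lives in example_app, which is not part of this commit; a rough Python-only equivalent of the round trip, assuming sru_ts.pt has already been written by the script above, is:

import torch

# Hedged sketch (not part of the commit): reload the scripted module and print
# its outputs in the same "{:.4f} " format the test writes to py_out.txt, so
# the result can be diffed the same way cpp_out.txt is.
loaded = torch.jit.load('sru_ts.pt')
loaded.eval()

with torch.no_grad():
    x = torch.ones(3, 2, 4)   # same (seq_len, batch, D) input as the test
    h, c = loaded(x)
    print(''.join("{:.4f} ".format(v.item()) for v in h.view(-1)))
    print(''.join("{:.4f} ".format(v.item()) for v in c.view(-1)))
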
