Skip to content

Commit

Permalink
test normalize-after in torchscript tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hpasapp committed Dec 17, 2020
1 parent b8235f0 commit 265f75d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
7 changes: 6 additions & 1 deletion test/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
set -e
set -x

python test/test_ts_cpp.py > py_out.txt
cd sru/csrc/
if [[ -d build ]]; then {
rm -Rf build
Expand All @@ -17,5 +16,11 @@ cd build
cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch; import os.path; print(os.path.join(os.path.dirname(torch.__file__), "share", "cmake"))')" ..
make -j
cd ../../../

python test/test_ts_cpp.py > py_out.txt
sru/csrc/build/example_app sru_ts.pt > cpp_out.txt
diff cpp_out.txt py_out.txt

python test/test_ts_cpp.py --normalize-after > py_out.txt
sru/csrc/build/example_app sru_ts.pt > cpp_out.txt
diff cpp_out.txt py_out.txt
32 changes: 21 additions & 11 deletions test/test_ts_cpp.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
import torch
import sru
import argparse

D = 4
model = sru.SRU(D, D, num_layers=2)
model.eval()

ts_model = torch.jit.script(model)
ts_model.save('sru_ts.pt')
def run(args):
D = 4
model = sru.SRU(D, D, num_layers=2, normalize_after=args.normalize_after)
model.eval()

with torch.no_grad():
x = torch.ones(3, 2, D)
h, c = model(x)
h, c = h.view(-1), c.view(-1)
print(''.join(["{:.4f} ".format(x.item()) for x in h]))
print(''.join(["{:.4f} ".format(x.item()) for x in c]))
ts_model = torch.jit.script(model)
ts_model.save('sru_ts.pt')

with torch.no_grad():
x = torch.ones(3, 2, D)
h, c = model(x)
h, c = h.view(-1), c.view(-1)
print(''.join(["{:.4f} ".format(x.item()) for x in h]))
print(''.join(["{:.4f} ".format(x.item()) for x in c]))


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--normalize-after', action='store_true')
args = parser.parse_args()
run(args)

0 comments on commit 265f75d

Please sign in to comment.