
Commit

update unit tests for CUDA/GPU
taoleicn committed May 18, 2021
1 parent 49eaa9a commit 26e2579
Showing 3 changed files with 26 additions and 10 deletions.
6 changes: 4 additions & 2 deletions test/sru/test_sru.py
@@ -17,7 +17,9 @@
 )
 @pytest.mark.parametrize("with_grad", [False, True])
 @pytest.mark.parametrize("compat", [False, True])
-def test_cell(cuda, with_grad, compat):
+@pytest.mark.parametrize("bidirectional", [False, True])
+@pytest.mark.parametrize("layer_norm", [False, True])
+def test_cell(cuda, with_grad, compat, bidirectional, layer_norm):
     torch.manual_seed(123)
     if cuda:
         torch.backends.cudnn.deterministic = True
@@ -30,12 +32,12 @@ def run():
         rnn_hidden = 4
         max_len = 4
         layers = 5
-        bidirectional = True
         encoder = sru.SRU(
             embedding_size,
             rnn_hidden,
             layers,
             bidirectional=bidirectional,
+            layer_norm=layer_norm,
             nn_rnn_compatible_return=compat,
         )
         words_embeddings = torch.rand(
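This change promotes the previously hard-coded `bidirectional = True` into a pytest parameter and adds `layer_norm` alongside it; stacked `@pytest.mark.parametrize` decorators run the test over the full cartesian product of the flags. A minimal, self-contained sketch of that pattern (the test body below is illustrative, not the repository's own):

```python
import pytest
import torch

import sru  # assumes the sru package is installed


@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("layer_norm", [False, True])
def test_cell_shapes(bidirectional, layer_norm):
    # Stacked parametrize decorators multiply: pytest collects 2 x 2 = 4 cases.
    torch.manual_seed(123)
    encoder = sru.SRU(
        4,  # input size
        4,  # hidden size
        2,  # number of layers
        bidirectional=bidirectional,
        layer_norm=layer_norm,
    )
    x = torch.rand(5, 3, 4)  # (seq_len, batch, input_size)
    out, state = encoder(x)
    dirs = 2 if bidirectional else 1
    assert out.shape == (5, 3, 4 * dirs)
```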
21 changes: 20 additions & 1 deletion test/sru/test_torchscript.py
@@ -3,13 +3,29 @@
 import sru


+@pytest.mark.parametrize(
+    "cuda",
+    [
+        False,
+        pytest.param(
+            True,
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="no cuda available"
+            ),
+        ),
+    ],
+)
 @pytest.mark.parametrize("bidirectional", [False, True])
 @pytest.mark.parametrize("rescale", [False, True])
 @pytest.mark.parametrize("proj", [0, 4])
 @pytest.mark.parametrize("layer_norm", [False, True])
-def test_all(bidirectional, rescale, proj, layer_norm):
+def test_all(cuda, bidirectional, rescale, proj, layer_norm):
     eps = 1e-4
     torch.manual_seed(1234)
+    if cuda:
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
     L = 16
     B = 8
     D = 32
@@ -18,6 +34,9 @@ def test_all(bidirectional, rescale, proj, layer_norm):
                     projection_size=proj,
                     layer_norm=layer_norm,
                     rescale=rescale)
+    if cuda:
+        model = model.cuda()
+        x = x.cuda()
     model.eval()

     h, c = model(x)
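The new `cuda` parameter uses `pytest.param` with a `skipif` mark, so the GPU variant is collected everywhere but skipped cleanly on machines without CUDA rather than failing. A stripped-down sketch of that pattern in isolation (the assertions are illustrative):

```python
import pytest
import torch


@pytest.mark.parametrize(
    "cuda",
    [
        False,
        # The True case is always collected, but skipped when no GPU exists.
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="no cuda available"
            ),
        ),
    ],
)
def test_device_roundtrip(cuda):
    x = torch.ones(2, 3)
    if cuda:
        x = x.cuda()
        assert x.is_cuda
    assert x.sum().item() == 6.0
```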
9 changes: 2 additions & 7 deletions test/test_amp.py
@@ -41,13 +41,8 @@ def sync(device: str) -> None:


 @pytest.mark.skipif(not has_amp, reason='AMP not available')
-@pytest.mark.parametrize(
-    'use_amp,fp16_recurrence', [
-        [False, False],
-        [False, False],
-        [True, False],
-        [True, True]]
-)
+@pytest.mark.parametrize('use_amp', [False, True])
+@pytest.mark.parametrize('fp16_recurrence', [False, True])
def test_amp(use_amp: bool, fp16_recurrence: bool):
     its = 20
     warmup = 3
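The deleted hand-written pair list contained `[False, False]` twice and never covered `[False, True]`; the two stacked single-argument decorators generate each of the four combinations exactly once. A small sketch (hypothetical test, not from this commit) of the resulting 2 x 2 grid:

```python
import pytest


@pytest.mark.parametrize('use_amp', [False, True])
@pytest.mark.parametrize('fp16_recurrence', [False, True])
def test_grid(use_amp: bool, fp16_recurrence: bool):
    # pytest collects one case per combination: 2 x 2 = 4 tests, including
    # (use_amp=False, fp16_recurrence=True), which the old pair list missed.
    assert isinstance(use_amp, bool) and isinstance(fp16_recurrence, bool)
```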
