Skip to content

Commit 2d9ae40

Browse files
committed
update tests
1 parent 5c58048 commit 2d9ae40

File tree

2 files changed

+9
-28
lines changed

2 files changed

+9
-28
lines changed

Diff for: model/test_net.py

-19
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,6 @@
1010
class Tests(unittest.TestCase):
1111

1212

13-
def test_weight_dropout(self):
14-
# Instantiate an LSTM net with weight dropout on
15-
# hidden-to-hidden weights
16-
model = net.AWD_LSTM(10, 20, 2)
17-
18-
# Apply drop connect with probabilit 1.0
19-
model.weight_dropout(p=1.0)
20-
new_state = model.state_dict()
21-
22-
# Test
23-
# Because dropout probability is 1
24-
# all hidden-to-hidden weights should be zero
25-
# There are two such weight matrices because we
26-
# configured a 2-layer lstm
27-
self.assertTrue(new_state['layer0.h2h.weight'].sum() == 0)
28-
self.assertTrue(new_state['layer1.h2h.weight'].sum() == 0)
29-
# Here we test that input-to-hidden weights are unaffected
30-
self.assertTrue(new_state['layer0.i2h.weight'].sum() != 0)
31-
self.assertTrue(new_state['layer1.i2h.weight'].sum() != 0)
3213

3314
def test_activation_reg(self):
3415
# make deterministe

Diff for: test_utils.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,25 @@ def test_NT_ASGD(self):
1919
# ASGD used in place of SGD optimizer when
2020
# loss increases for n succesive calls to get_optimizer
2121
self.assertEqual(nt_asgd.asgd_triggered, False)
22-
nt_asgd.get_optimizer(3, model.parameters())
22+
nt_asgd.get_optimizer(3)
2323
self.assertEqual(nt_asgd.asgd_triggered, False)
24-
nt_asgd.get_optimizer(2, model.parameters())
24+
nt_asgd.get_optimizer(2)
2525
self.assertEqual(nt_asgd.asgd_triggered, False)
26-
nt_asgd.get_optimizer(3, model.parameters())
26+
nt_asgd.get_optimizer(3)
2727
self.assertEqual(nt_asgd.asgd_triggered, False)
28-
nt_asgd.get_optimizer(3, model.parameters())
28+
nt_asgd.get_optimizer(3)
2929
self.assertEqual(nt_asgd.asgd_triggered, False)
3030
# ASGD Triggered because loss was lowest n+1 epochs ago
31-
nt_asgd.get_optimizer(4, model.parameters())
31+
nt_asgd.get_optimizer(4)
3232
self.assertEqual(nt_asgd.asgd_triggered, True)
33-
nt_asgd.get_optimizer(2, model.parameters())
33+
nt_asgd.get_optimizer(2)
3434
self.assertEqual(nt_asgd.asgd_triggered, True)
35-
nt_asgd.get_optimizer(3, model.parameters())
35+
nt_asgd.get_optimizer(3)
3636
self.assertEqual(nt_asgd.asgd_triggered, True)
37-
nt_asgd.get_optimizer(3, model.parameters())
37+
nt_asgd.get_optimizer(3)
3838
self.assertEqual(nt_asgd.asgd_triggered, True)
3939
# Doesn't un-trigger
40-
nt_asgd.get_optimizer(4, model.parameters())
40+
nt_asgd.get_optimizer(4)
4141
self.assertEqual(nt_asgd.asgd_triggered, True)
4242

4343

0 commit comments

Comments
 (0)