@@ -19,25 +19,25 @@ def test_NT_ASGD(self):
19
19
# ASGD used in place of SGD optimizer when
20
20
# loss increases for n successive calls to get_optimizer
21
21
self .assertEqual (nt_asgd .asgd_triggered , False )
22
- nt_asgd .get_optimizer (3 , model . parameters () )
22
+ nt_asgd .get_optimizer (3 )
23
23
self .assertEqual (nt_asgd .asgd_triggered , False )
24
- nt_asgd .get_optimizer (2 , model . parameters () )
24
+ nt_asgd .get_optimizer (2 )
25
25
self .assertEqual (nt_asgd .asgd_triggered , False )
26
- nt_asgd .get_optimizer (3 , model . parameters () )
26
+ nt_asgd .get_optimizer (3 )
27
27
self .assertEqual (nt_asgd .asgd_triggered , False )
28
- nt_asgd .get_optimizer (3 , model . parameters () )
28
+ nt_asgd .get_optimizer (3 )
29
29
self .assertEqual (nt_asgd .asgd_triggered , False )
30
30
# ASGD Triggered because loss was lowest n+1 epochs ago
31
- nt_asgd .get_optimizer (4 , model . parameters () )
31
+ nt_asgd .get_optimizer (4 )
32
32
self .assertEqual (nt_asgd .asgd_triggered , True )
33
- nt_asgd .get_optimizer (2 , model . parameters () )
33
+ nt_asgd .get_optimizer (2 )
34
34
self .assertEqual (nt_asgd .asgd_triggered , True )
35
- nt_asgd .get_optimizer (3 , model . parameters () )
35
+ nt_asgd .get_optimizer (3 )
36
36
self .assertEqual (nt_asgd .asgd_triggered , True )
37
- nt_asgd .get_optimizer (3 , model . parameters () )
37
+ nt_asgd .get_optimizer (3 )
38
38
self .assertEqual (nt_asgd .asgd_triggered , True )
39
39
# Doesn't un-trigger
40
- nt_asgd .get_optimizer (4 , model . parameters () )
40
+ nt_asgd .get_optimizer (4 )
41
41
self .assertEqual (nt_asgd .asgd_triggered , True )
42
42
43
43
0 commit comments