1
1
#!/usr/bin/env python3
2
2
# -*- coding: utf-8 -*-
3
+
3
4
import functools
4
5
import pytest
5
6
6
- from checkpoint_schedules import MixedCheckpointSchedule , Copy , \
7
- Forward , Reverse , EndForward , EndReverse , Move , StorageType
8
- from checkpoint_schedules .mixed import optimal_steps_mixed , \
9
- mixed_step_memoization
7
+ from checkpoint_schedules import (
8
+ MixedCheckpointSchedule , Copy , Forward , Reverse , EndForward , EndReverse ,
9
+ Move , StorageType )
10
+ from checkpoint_schedules .mixed import (
11
+ optimal_steps_mixed , mixed_step_memoization )
10
12
11
13
12
14
@pytest .mark .parametrize ("n, S" , [
13
15
(1 , (0 ,)),
14
16
(2 , (1 ,)),
15
17
(3 , (1 , 2 )),
16
- (10 , tuple (range (1 , 10 ))),
18
+ (5 , (2 ,)),
19
+ (10 , tuple (range (2 , 10 ))),
17
20
(100 , tuple (range (1 , 100 ))),
18
21
(250 , tuple (range (25 , 250 , 25 )))
19
22
])
20
23
def test_mixed (n , S ):
21
- cp_action_lists = []
22
-
23
24
@functools .singledispatch
24
25
def action (cp_action ):
25
26
raise TypeError ("Unexpected action" )
26
27
27
28
@action .register (Forward )
28
29
def action_forward (cp_action ):
29
30
nonlocal model_n , model_steps
30
- nonlocal store_ics , store_data
31
- store_ics = cp_action .write_ics
32
- store_data = cp_action .write_adj_deps
31
+
33
32
ics .clear ()
34
33
data .clear ()
35
34
36
35
# Start at the current location of the forward
37
- assert model_n is not None and model_n == cp_action .n0
36
+ assert model_n is not None and cp_action .n0 == model_n
37
+ # Do not advance further than the current location of the adjoint
38
+ assert cp_action .n1 <= n - model_r
38
39
39
- if store_ics :
40
- assert cp_action .storage == StorageType .DISK
41
- assert cp_schedule .uses_storage_type (StorageType .DISK )
40
+ if cp_action .write_ics :
42
41
# Advance at least two steps when storing forward restart data
43
42
assert cp_action .n1 > cp_action .n0 + 1
44
43
# Do not advance further than one step before the current location
45
44
# of the adjoint
46
45
assert cp_action .n1 < n - model_r
47
- # No data for these steps is stored
48
- assert len (ics .intersection (range (cp_action .n0 , cp_action .n1 ))) == 0 # noqa: E501
49
46
50
- if store_data :
47
+ assert cp_action .storage == StorageType .DISK
48
+ assert cp_schedule .uses_storage_type (StorageType .DISK )
49
+ ics .update (range (cp_action .n0 , cp_action .n1 ))
50
+ assert cp_action .n0 not in snapshots
51
+ snapshots [cp_action .n0 ] = (set (ics ), set (data ))
52
+
53
+ if cp_action .write_adj_deps :
51
54
# Advance exactly one step when storing non-linear dependency data
52
55
assert cp_action .n1 == cp_action .n0 + 1
53
56
# Do not advance further than the current location of the adjoint
54
57
assert cp_action .n1 <= n - model_r
55
- # No data for this step is stored
56
- assert len (data .intersection (range (cp_action .n0 , cp_action .n1 ))) == 0 # noqa: E501
57
-
58
- model_n = cp_action .n1
59
- model_steps += cp_action .n1 - cp_action .n0
60
- if store_ics :
61
- ics .update (range (cp_action .n0 , cp_action .n1 ))
62
- snapshots [cp_action .n0 ] = (set (ics ), set (data ))
63
58
64
- if store_data :
65
59
data .update (range (cp_action .n0 , cp_action .n1 ))
66
60
if cp_action .storage == StorageType .DISK :
61
+ assert cp_schedule .uses_storage_type (StorageType .DISK )
62
+ assert cp_action .n0 not in snapshots
67
63
snapshots [cp_action .n0 ] = (set (ics ), set (data ))
64
+ else :
65
+ assert cp_action .storage == StorageType .WORK
68
66
69
- if store_ics or store_data :
70
- # Written data consists of either forward restart or non-linear
71
- # dependency data, but not both
72
- assert len (ics ) == 0 or len (data ) == 0
73
- assert len (ics ) > 0 or len (data ) > 0
74
-
75
- # Non-linear dependency data is either not stored, or is stored for a
76
- # single step
77
- assert len (data ) <= 1
67
+ # Stored data consists of either forward restart or non-linear
68
+ # dependency data, but not both
69
+ assert not cp_action .write_ics or not cp_action .write_adj_deps
78
70
79
- # The checkpoint location is associated with the earliest step for
80
- # which data has been stored
81
- if len (ics ) > 0 :
82
- assert cp_action .n0 == min (ics )
83
- if len (data ) > 0 :
84
- assert cp_action .n0 == min (data )
71
+ model_n = cp_action .n1
72
+ model_steps += cp_action .n1 - cp_action .n0
85
73
86
74
@action .register (Reverse )
87
75
def action_reverse (cp_action ):
@@ -93,8 +81,10 @@ def action_reverse(cp_action):
93
81
assert cp_action .n0 == cp_action .n1 - 1
94
82
# Non-linear dependency data for the step is stored
95
83
assert cp_action .n0 in data
84
+
96
85
if cp_action .clear_adj_deps :
97
86
data .clear ()
87
+
98
88
model_r += 1
99
89
100
90
@action .register (Copy )
@@ -107,33 +97,28 @@ def action_copy(cp_action):
107
97
108
98
cp = snapshots [cp_action .n ]
109
99
110
- # No data is currently stored for this step
111
- assert cp_action .n not in ics
112
- assert cp_action .n not in data
113
- # The checkpoint contains either forward restart or non-linear
114
- # dependency data, but not both
115
- assert len (cp [0 ]) == 0 or len (cp [1 ]) == 0
116
- assert len (cp [0 ]) > 0 or len (cp [1 ]) > 0
100
+ # No data is currently stored
101
+ assert len (ics ) == 0
102
+ assert len (data ) == 0
103
+ # The checkpoint contains forward restart data
104
+ assert len (cp [0 ]) > 0 and len (cp [1 ]) == 0
117
105
118
- if len (cp [0 ]) > 0 :
119
- # Loading a forward restart checkpoint:
106
+ # Loading a forward restart checkpoint:
120
107
121
- # The checkpoint data is at least two steps away from the current
122
- # location of the adjoint
123
- assert cp_action .n < n - model_r - 1
124
- ics .update (cp [0 ])
125
- model_n = cp_action .n
108
+ # The checkpoint data is at least two steps away from the current
109
+ # location of the adjoint
110
+ assert cp_action .n < n - model_r - 1
111
+ # The loaded data is deleted iff non-linear dependency data for all
112
+ # remaining steps can be stored
113
+ assert cp_action .n < n - model_r - 1 - (s - len (snapshots ) + 1 )
126
114
127
- if len ( cp [ 1 ]) > 0 :
128
- # Loading a non-linear dependency data checkpoint:
115
+ assert cp_action . to_storage == StorageType . WORK
116
+ ics . update ( cp [ 0 ])
129
117
130
- # The checkpoint data is exactly one step away from the current
131
- # location of the adjoint
132
- assert cp_action .n == n - model_r - 1
118
+ model_n = cp_action .n
133
119
134
- data .clear ()
135
- data .update (cp [1 ])
136
- model_n = None
120
+ # Can advance the forward to the current location of the adjoint
121
+ assert ics .issuperset (range (model_n , n - model_r ))
137
122
138
123
@action .register (Move )
139
124
def action_move (cp_action ):
@@ -142,30 +127,46 @@ def action_move(cp_action):
142
127
# The checkpoint exists
143
128
assert cp_action .n in snapshots
144
129
assert cp_action .from_storage == StorageType .DISK
145
- cp = snapshots [cp_action .n ]
146
130
147
- # The checkpoint contains forward restart data
131
+ cp = snapshots .pop (cp_action .n )
132
+
133
+ # No data is currently stored
134
+ assert len (ics ) == 0
135
+ assert len (data ) == 0
136
+ # The checkpoint contains either forward restart or non-linear
137
+ # dependency data, but not both
148
138
assert len (cp [0 ]) == 0 or len (cp [1 ]) == 0
149
139
assert len (cp [0 ]) > 0 or len (cp [1 ]) > 0
150
140
151
141
if len (cp [0 ]) > 0 :
142
+ # Loading a forward restart checkpoint:
143
+
152
144
# The checkpoint data is at least two steps away from the current
153
145
# location of the adjoint
154
146
assert cp_action .n < n - model_r - 1
147
+ # The loaded data is deleted iff non-linear dependency data for all
148
+ # remaining steps can be stored
149
+ assert cp_action .n >= n - model_r - 1 - (s - len (snapshots ) + 1 )
150
+
151
+ assert cp_action .to_storage == StorageType .WORK
155
152
ics .update (cp [0 ])
153
+
156
154
model_n = cp_action .n
157
155
156
+ # Can advance the forward to the current location of the adjoint
157
+ assert ics .issuperset (range (model_n , n - model_r ))
158
+
158
159
if len (cp [1 ]) > 0 :
159
160
# Loading a non-linear dependency data checkpoint:
161
+
160
162
# The checkpoint data is exactly one step away from the current
161
163
# location of the adjoint
162
164
assert cp_action .n == n - model_r - 1
163
165
164
- data . clear ()
166
+ assert cp_action . to_storage == StorageType . WORK
165
167
data .update (cp [1 ])
166
- model_n = None
167
168
168
- del snapshots [ cp_action . n ]
169
+ model_n = None
169
170
170
171
@action .register (EndForward )
171
172
def action_end_forward (cp_action ):
@@ -176,6 +177,8 @@ def action_end_forward(cp_action):
176
177
def action_end_reverse (cp_action ):
177
178
# The correct number of adjoint steps has been taken
178
179
assert model_r == n
180
+ # The schedule has concluded
181
+ assert cp_schedule .is_exhausted
179
182
180
183
for s in S :
181
184
print (f"{ n = :d} { s = :d} " )
@@ -184,9 +187,7 @@ def action_end_reverse(cp_action):
184
187
model_r = 0
185
188
model_steps = 0
186
189
187
- store_ics = False
188
190
ics = set ()
189
- store_data = False
190
191
data = set ()
191
192
192
193
snapshots = {}
@@ -199,24 +200,20 @@ def action_end_reverse(cp_action):
199
200
200
201
for _ , cp_action in enumerate (cp_schedule ):
201
202
action (cp_action )
202
- cp_action_lists .append (cp_action )
203
203
# The schedule state is consistent with both the forward and
204
204
# adjoint
205
205
assert model_n is None or model_n == cp_schedule .n
206
206
assert model_r == cp_schedule .r
207
+ assert cp_schedule .max_n == n
207
208
208
- # Either no data is being stored, or exactly one of forward restart
209
- # or non-linear dependency data is being stored
210
- assert not store_ics or not store_data
209
+ # Either no data is stored, or exactly one of forward restart or
210
+ # non-linear dependency data is stored
211
211
assert len (ics ) == 0 or len (data ) == 0
212
212
# Non-linear dependency data is stored for at most one step
213
213
assert len (data ) <= 1
214
214
# Checkpoint storage limits are not exceeded
215
215
assert len (snapshots ) <= s
216
216
217
- if isinstance (cp_action , EndReverse ):
218
- break
219
-
220
217
# The correct total number of forward steps has been taken
221
218
assert model_steps == optimal_steps_mixed (n , s )
222
219
assert model_steps == mixed_step_memoization (n , s )[2 ]
0 commit comments