Skip to content

Commit 055c881

Browse files
authored
Merge pull request #52 from firedrakeproject/jrmaddison/test_mixed
Update mixed unit tests
2 parents 46c6c51 + e9e1a54 commit 055c881

File tree

1 file changed

+74
-77
lines changed

1 file changed

+74
-77
lines changed

Diff for: tests/test_mixed.py

+74-77
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,75 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3+
34
import functools
45
import pytest
56

6-
from checkpoint_schedules import MixedCheckpointSchedule, Copy, \
7-
Forward, Reverse, EndForward, EndReverse, Move, StorageType
8-
from checkpoint_schedules.mixed import optimal_steps_mixed, \
9-
mixed_step_memoization
7+
from checkpoint_schedules import (
8+
MixedCheckpointSchedule, Copy, Forward, Reverse, EndForward, EndReverse,
9+
Move, StorageType)
10+
from checkpoint_schedules.mixed import (
11+
optimal_steps_mixed, mixed_step_memoization)
1012

1113

1214
@pytest.mark.parametrize("n, S", [
1315
(1, (0,)),
1416
(2, (1,)),
1517
(3, (1, 2)),
16-
(10, tuple(range(1, 10))),
18+
(5, (2,)),
19+
(10, tuple(range(2, 10))),
1720
(100, tuple(range(1, 100))),
1821
(250, tuple(range(25, 250, 25)))
1922
])
2023
def test_mixed(n, S):
21-
cp_action_lists = []
22-
2324
@functools.singledispatch
2425
def action(cp_action):
2526
raise TypeError("Unexpected action")
2627

2728
@action.register(Forward)
2829
def action_forward(cp_action):
2930
nonlocal model_n, model_steps
30-
nonlocal store_ics, store_data
31-
store_ics = cp_action.write_ics
32-
store_data = cp_action.write_adj_deps
31+
3332
ics.clear()
3433
data.clear()
3534

3635
# Start at the current location of the forward
37-
assert model_n is not None and model_n == cp_action.n0
36+
assert model_n is not None and cp_action.n0 == model_n
37+
# Do not advance further than the current location of the adjoint
38+
assert cp_action.n1 <= n - model_r
3839

39-
if store_ics:
40-
assert cp_action.storage == StorageType.DISK
41-
assert cp_schedule.uses_storage_type(StorageType.DISK)
40+
if cp_action.write_ics:
4241
# Advance at least two steps when storing forward restart data
4342
assert cp_action.n1 > cp_action.n0 + 1
4443
# Do not advance further than one step before the current location
4544
# of the adjoint
4645
assert cp_action.n1 < n - model_r
47-
# No data for these steps is stored
48-
assert len(ics.intersection(range(cp_action.n0, cp_action.n1))) == 0 # noqa: E501
4946

50-
if store_data:
47+
assert cp_action.storage == StorageType.DISK
48+
assert cp_schedule.uses_storage_type(StorageType.DISK)
49+
ics.update(range(cp_action.n0, cp_action.n1))
50+
assert cp_action.n0 not in snapshots
51+
snapshots[cp_action.n0] = (set(ics), set(data))
52+
53+
if cp_action.write_adj_deps:
5154
# Advance exactly one step when storing non-linear dependency data
5255
assert cp_action.n1 == cp_action.n0 + 1
5356
# Do not advance further than the current location of the adjoint
5457
assert cp_action.n1 <= n - model_r
55-
# No data for this step is stored
56-
assert len(data.intersection(range(cp_action.n0, cp_action.n1))) == 0 # noqa: E501
57-
58-
model_n = cp_action.n1
59-
model_steps += cp_action.n1 - cp_action.n0
60-
if store_ics:
61-
ics.update(range(cp_action.n0, cp_action.n1))
62-
snapshots[cp_action.n0] = (set(ics), set(data))
6358

64-
if store_data:
6559
data.update(range(cp_action.n0, cp_action.n1))
6660
if cp_action.storage == StorageType.DISK:
61+
assert cp_schedule.uses_storage_type(StorageType.DISK)
62+
assert cp_action.n0 not in snapshots
6763
snapshots[cp_action.n0] = (set(ics), set(data))
64+
else:
65+
assert cp_action.storage == StorageType.WORK
6866

69-
if store_ics or store_data:
70-
# Written data consists of either forward restart or non-linear
71-
# dependency data, but not both
72-
assert len(ics) == 0 or len(data) == 0
73-
assert len(ics) > 0 or len(data) > 0
74-
75-
# Non-linear dependency data is either not stored, or is stored for a
76-
# single step
77-
assert len(data) <= 1
67+
# Stored data consists of either forward restart or non-linear
68+
# dependency data, but not both
69+
assert not cp_action.write_ics or not cp_action.write_adj_deps
7870

79-
# The checkpoint location is associated with the earliest step for
80-
# which data has been stored
81-
if len(ics) > 0:
82-
assert cp_action.n0 == min(ics)
83-
if len(data) > 0:
84-
assert cp_action.n0 == min(data)
71+
model_n = cp_action.n1
72+
model_steps += cp_action.n1 - cp_action.n0
8573

8674
@action.register(Reverse)
8775
def action_reverse(cp_action):
@@ -93,8 +81,10 @@ def action_reverse(cp_action):
9381
assert cp_action.n0 == cp_action.n1 - 1
9482
# Non-linear dependency data for the step is stored
9583
assert cp_action.n0 in data
84+
9685
if cp_action.clear_adj_deps:
9786
data.clear()
87+
9888
model_r += 1
9989

10090
@action.register(Copy)
@@ -107,33 +97,28 @@ def action_copy(cp_action):
10797

10898
cp = snapshots[cp_action.n]
10999

110-
# No data is currently stored for this step
111-
assert cp_action.n not in ics
112-
assert cp_action.n not in data
113-
# The checkpoint contains either forward restart or non-linear
114-
# dependency data, but not both
115-
assert len(cp[0]) == 0 or len(cp[1]) == 0
116-
assert len(cp[0]) > 0 or len(cp[1]) > 0
100+
# No data is currently stored
101+
assert len(ics) == 0
102+
assert len(data) == 0
103+
# The checkpoint contains forward restart data
104+
assert len(cp[0]) > 0 and len(cp[1]) == 0
117105

118-
if len(cp[0]) > 0:
119-
# Loading a forward restart checkpoint:
106+
# Loading a forward restart checkpoint:
120107

121-
# The checkpoint data is at least two steps away from the current
122-
# location of the adjoint
123-
assert cp_action.n < n - model_r - 1
124-
ics.update(cp[0])
125-
model_n = cp_action.n
108+
# The checkpoint data is at least two steps away from the current
109+
# location of the adjoint
110+
assert cp_action.n < n - model_r - 1
111+
# The loaded data is deleted iff non-linear dependency data for all
112+
# remaining steps can be stored
113+
assert cp_action.n < n - model_r - 1 - (s - len(snapshots) + 1)
126114

127-
if len(cp[1]) > 0:
128-
# Loading a non-linear dependency data checkpoint:
115+
assert cp_action.to_storage == StorageType.WORK
116+
ics.update(cp[0])
129117

130-
# The checkpoint data is exactly one step away from the current
131-
# location of the adjoint
132-
assert cp_action.n == n - model_r - 1
118+
model_n = cp_action.n
133119

134-
data.clear()
135-
data.update(cp[1])
136-
model_n = None
120+
# Can advance the forward to the current location of the adjoint
121+
assert ics.issuperset(range(model_n, n - model_r))
137122

138123
@action.register(Move)
139124
def action_move(cp_action):
@@ -142,30 +127,46 @@ def action_move(cp_action):
142127
# The checkpoint exists
143128
assert cp_action.n in snapshots
144129
assert cp_action.from_storage == StorageType.DISK
145-
cp = snapshots[cp_action.n]
146130

147-
# The checkpoint contains forward restart data
131+
cp = snapshots.pop(cp_action.n)
132+
133+
# No data is currently stored
134+
assert len(ics) == 0
135+
assert len(data) == 0
136+
# The checkpoint contains either forward restart or non-linear
137+
# dependency data, but not both
148138
assert len(cp[0]) == 0 or len(cp[1]) == 0
149139
assert len(cp[0]) > 0 or len(cp[1]) > 0
150140

151141
if len(cp[0]) > 0:
142+
# Loading a forward restart checkpoint:
143+
152144
# The checkpoint data is at least two steps away from the current
153145
# location of the adjoint
154146
assert cp_action.n < n - model_r - 1
147+
# The loaded data is deleted iff non-linear dependency data for all
148+
# remaining steps can be stored
149+
assert cp_action.n >= n - model_r - 1 - (s - len(snapshots) + 1)
150+
151+
assert cp_action.to_storage == StorageType.WORK
155152
ics.update(cp[0])
153+
156154
model_n = cp_action.n
157155

156+
# Can advance the forward to the current location of the adjoint
157+
assert ics.issuperset(range(model_n, n - model_r))
158+
158159
if len(cp[1]) > 0:
159160
# Loading a non-linear dependency data checkpoint:
161+
160162
# The checkpoint data is exactly one step away from the current
161163
# location of the adjoint
162164
assert cp_action.n == n - model_r - 1
163165

164-
data.clear()
166+
assert cp_action.to_storage == StorageType.WORK
165167
data.update(cp[1])
166-
model_n = None
167168

168-
del snapshots[cp_action.n]
169+
model_n = None
169170

170171
@action.register(EndForward)
171172
def action_end_forward(cp_action):
@@ -176,6 +177,8 @@ def action_end_forward(cp_action):
176177
def action_end_reverse(cp_action):
177178
# The correct number of adjoint steps has been taken
178179
assert model_r == n
180+
# The schedule has concluded
181+
assert cp_schedule.is_exhausted
179182

180183
for s in S:
181184
print(f"{n=:d} {s=:d}")
@@ -184,9 +187,7 @@ def action_end_reverse(cp_action):
184187
model_r = 0
185188
model_steps = 0
186189

187-
store_ics = False
188190
ics = set()
189-
store_data = False
190191
data = set()
191192

192193
snapshots = {}
@@ -199,24 +200,20 @@ def action_end_reverse(cp_action):
199200

200201
for _, cp_action in enumerate(cp_schedule):
201202
action(cp_action)
202-
cp_action_lists.append(cp_action)
203203
# The schedule state is consistent with both the forward and
204204
# adjoint
205205
assert model_n is None or model_n == cp_schedule.n
206206
assert model_r == cp_schedule.r
207+
assert cp_schedule.max_n == n
207208

208-
# Either no data is being stored, or exactly one of forward restart
209-
# or non-linear dependency data is being stored
210-
assert not store_ics or not store_data
209+
# Either no data is stored, or exactly one of forward restart or
210+
# non-linear dependency data is stored
211211
assert len(ics) == 0 or len(data) == 0
212212
# Non-linear dependency data is stored for at most one step
213213
assert len(data) <= 1
214214
# Checkpoint storage limits are not exceeded
215215
assert len(snapshots) <= s
216216

217-
if isinstance(cp_action, EndReverse):
218-
break
219-
220217
# The correct total number of forward steps has been taken
221218
assert model_steps == optimal_steps_mixed(n, s)
222219
assert model_steps == mixed_step_memoization(n, s)[2]

0 commit comments

Comments
 (0)