|
10 | 10 |
|
11 | 11 | import autograd as ag |
12 | 12 | import autograd.numpy as anp |
| 13 | +import h5py |
13 | 14 | import matplotlib.pylab as plt |
14 | 15 | import numpy as np |
15 | 16 | import numpy.testing as npt |
|
25 | 26 | MINIMUM_SPACING_FRACTION, |
26 | 27 | ) |
27 | 28 | from tidy3d.components.autograd.derivative_utils import DerivativeInfo |
| 29 | +from tidy3d.components.autograd.field_map import FieldMap |
28 | 30 | from tidy3d.components.autograd.utils import is_tidy_box |
| 31 | +from tidy3d.components.base import TRACED_FIELD_KEYS_ATTR |
29 | 32 | from tidy3d.components.data.data_array import DataArray |
30 | 33 | from tidy3d.exceptions import AdjointError |
31 | 34 | from tidy3d.plugins.polyslab import ComplexPolySlab |
32 | 35 | from tidy3d.web import run, run_async |
33 | | -from tidy3d.web.api.autograd.utils import FieldMap |
| 36 | +from tidy3d.web.api.autograd import autograd as autograd_module |
34 | 37 |
|
35 | 38 | from ...utils import SIM_FULL, AssertLogLevel, run_emulated, tracer_arr |
36 | 39 |
|
@@ -1174,6 +1177,124 @@ def objective(*params): |
1174 | 1177 | ag.grad(objective)(params0) |
1175 | 1178 |
|
1176 | 1179 |
|
| 1180 | +def test_sim_hash_changes_with_traced_keys(): |
| 1181 | + """Ensure the model hash accounts for autograd traced paths.""" |
| 1182 | + |
| 1183 | + sim_traced = SIM_FULL.copy() |
| 1184 | + original_field_map = sim_traced._strip_traced_fields() |
| 1185 | + |
| 1186 | + structures = list(sim_traced.structures) |
| 1187 | + structures[0] = structures[0].to_static() |
| 1188 | + sim_modified = sim_traced.updated_copy(structures=tuple(structures)) |
| 1189 | + |
| 1190 | + modified_field_map = sim_modified._strip_traced_fields() |
| 1191 | + assert original_field_map != modified_field_map |
| 1192 | + assert sim_traced._hash_self() != sim_modified._hash_self() |
| 1193 | + |
| 1194 | + |
| 1195 | +def test_sim_hdf5_records_traced_keys(tmp_path): |
| 1196 | + """HDF5 exports should include traced-key metadata for caching.""" |
| 1197 | + |
| 1198 | + sim_traced = SIM_FULL.copy() |
| 1199 | + expected_payload = sim_traced._serialized_traced_field_keys() |
| 1200 | + assert expected_payload, "simulation fixture must yield traced keys" |
| 1201 | + |
| 1202 | + sim_traced.attrs.pop(TRACED_FIELD_KEYS_ATTR, None) |
| 1203 | + |
| 1204 | + export_path = tmp_path / "sim_traced.hdf5" |
| 1205 | + sim_traced.to_hdf5(str(export_path)) |
| 1206 | + |
| 1207 | + with h5py.File(export_path, "r") as handle: |
| 1208 | + assert TRACED_FIELD_KEYS_ATTR in handle.attrs |
| 1209 | + assert handle.attrs[TRACED_FIELD_KEYS_ATTR] == expected_payload |
| 1210 | + |
| 1211 | + static_export = tmp_path / "sim_traced_static.hdf5" |
| 1212 | + sim_traced.attrs[TRACED_FIELD_KEYS_ATTR] = expected_payload |
| 1213 | + sim_static = sim_traced.to_static() |
| 1214 | + sim_static.to_hdf5(str(static_export)) |
| 1215 | + |
| 1216 | + with h5py.File(static_export, "r") as handle: |
| 1217 | + assert TRACED_FIELD_KEYS_ATTR in handle.attrs |
| 1218 | + assert handle.attrs[TRACED_FIELD_KEYS_ATTR] == expected_payload |
| 1219 | + |
| 1220 | + |
| 1221 | +def test_web_run_duplicate_simulations(monkeypatch): |
| 1222 | + """Repeated simulation objects should reuse cached data without hash mismatches.""" |
| 1223 | + |
| 1224 | + sim = SIM_FULL.copy() |
| 1225 | + sim.attrs.pop(TRACED_FIELD_KEYS_ATTR, None) |
| 1226 | + |
| 1227 | + copy_calls = {"count": 0} |
| 1228 | + |
| 1229 | + class DummyData: |
| 1230 | + def __init__(self, label: str): |
| 1231 | + self.label = label |
| 1232 | + |
| 1233 | + def copy(self): |
| 1234 | + copy_calls["count"] += 1 |
| 1235 | + return DummyData(f"{self.label}_copy{copy_calls['count']}") |
| 1236 | + |
| 1237 | + dummy = DummyData("root") |
| 1238 | + |
| 1239 | + def fake_run_autograd(*args, **kwargs): |
| 1240 | + return dummy |
| 1241 | + |
| 1242 | + monkeypatch.setattr("tidy3d.web.api.run.run_autograd", fake_run_autograd) |
| 1243 | + |
| 1244 | + results = web.run([sim, sim]) |
| 1245 | + |
| 1246 | + assert isinstance(results, list) |
| 1247 | + assert len(results) == 2 |
| 1248 | + assert results[0] is dummy |
| 1249 | + assert results[1] is not dummy |
| 1250 | + assert copy_calls["count"] == 1 |
| 1251 | + |
| 1252 | + |
| 1253 | +def test_autograd_run_does_not_mutate_input_attrs(monkeypatch): |
| 1254 | + """Autograd run should attach traced metadata only to the exported static copy.""" |
| 1255 | + |
| 1256 | + sim = SIM_FULL.copy() |
| 1257 | + sim.attrs.pop(TRACED_FIELD_KEYS_ATTR, None) |
| 1258 | + payload = sim._serialized_traced_field_keys() |
| 1259 | + assert payload |
| 1260 | + |
| 1261 | + captured: dict[str, typing.Any] = {} |
| 1262 | + |
| 1263 | + def fake_run_primitive( |
| 1264 | + sim_fields, |
| 1265 | + sim_original, |
| 1266 | + task_name, |
| 1267 | + aux_data, |
| 1268 | + local_gradient, |
| 1269 | + max_num_adjoint_per_fwd, |
| 1270 | + **run_kwargs, |
| 1271 | + ): |
| 1272 | + captured["sim_original"] = sim_original |
| 1273 | + captured["payload"] = sim_original.attrs.get(TRACED_FIELD_KEYS_ATTR) |
| 1274 | + captured["sim_fields"] = sim_fields |
| 1275 | + captured["aux_data"] = aux_data |
| 1276 | + return sim_fields |
| 1277 | + |
| 1278 | + def fake_postprocess_run(traced_fields_data, aux_data): |
| 1279 | + captured["postprocess_data"] = traced_fields_data |
| 1280 | + captured["postprocess_aux"] = aux_data |
| 1281 | + return "sentinel" |
| 1282 | + |
| 1283 | + monkeypatch.setattr(autograd_module, "_run_primitive", fake_run_primitive) |
| 1284 | + monkeypatch.setattr(autograd_module, "postprocess_run", fake_postprocess_run) |
| 1285 | + |
| 1286 | + result = autograd_module._run(simulation=sim, task_name="dummy") |
| 1287 | + |
| 1288 | + assert result == "sentinel" |
| 1289 | + assert sim.attrs.get(TRACED_FIELD_KEYS_ATTR) is None |
| 1290 | + assert captured["payload"] == payload |
| 1291 | + assert captured["sim_original"] is not sim |
| 1292 | + assert captured["sim_original"].attrs.get(TRACED_FIELD_KEYS_ATTR) == payload |
| 1293 | + assert captured["postprocess_data"] == captured["sim_fields"] |
| 1294 | + assert captured["postprocess_aux"] is captured["aux_data"] |
| 1295 | + assert captured["postprocess_aux"] == {} |
| 1296 | + |
| 1297 | + |
1177 | 1298 | def test_sim_traced_override_structures(): |
1178 | 1299 | """Make sure that sims with traced override structures are handled properly.""" |
1179 | 1300 |
|
|
0 commit comments