Skip to content

Commit efa6d66

Browse files
committed
fix(autograd): include traced keys in HDF5 hash input
1 parent be601f1 commit efa6d66

File tree

6 files changed

+249
-59
lines changed

6 files changed

+249
-59
lines changed

tests/test_components/autograd/test_autograd.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import autograd as ag
1212
import autograd.numpy as anp
13+
import h5py
1314
import matplotlib.pylab as plt
1415
import numpy as np
1516
import numpy.testing as npt
@@ -25,12 +26,14 @@
2526
MINIMUM_SPACING_FRACTION,
2627
)
2728
from tidy3d.components.autograd.derivative_utils import DerivativeInfo
29+
from tidy3d.components.autograd.field_map import FieldMap
2830
from tidy3d.components.autograd.utils import is_tidy_box
31+
from tidy3d.components.base import TRACED_FIELD_KEYS_ATTR
2932
from tidy3d.components.data.data_array import DataArray
3033
from tidy3d.exceptions import AdjointError
3134
from tidy3d.plugins.polyslab import ComplexPolySlab
3235
from tidy3d.web import run, run_async
33-
from tidy3d.web.api.autograd.utils import FieldMap
36+
from tidy3d.web.api.autograd import autograd as autograd_module
3437

3538
from ...utils import SIM_FULL, AssertLogLevel, run_emulated, tracer_arr
3639

@@ -1174,6 +1177,124 @@ def objective(*params):
11741177
ag.grad(objective)(params0)
11751178

11761179

1180+
def test_sim_hash_changes_with_traced_keys():
    """Check that changing which fields are traced changes the simulation hash.

    A copy of ``SIM_FULL`` is compared against a variant whose first structure
    has been replaced by its static version; both the stripped traced-field
    maps and the ``_hash_self()`` values are required to differ.
    """
    traced_sim = SIM_FULL.copy()
    field_map_before = traced_sim._strip_traced_fields()

    # Swap only the first structure for its static version; the rest stay as-is.
    all_structures = tuple(traced_sim.structures)
    replaced = (all_structures[0].to_static(),) + all_structures[1:]
    modified_sim = traced_sim.updated_copy(structures=replaced)

    field_map_after = modified_sim._strip_traced_fields()
    assert field_map_before != field_map_after
    assert traced_sim._hash_self() != modified_sim._hash_self()
1193+
1194+
1195+
def test_sim_hdf5_records_traced_keys(tmp_path):
    """Check that HDF5 exports carry the serialized traced keys as a file attribute.

    Two exports are verified: one from the simulation after the traced-keys
    entry has been removed from its ``attrs``, and one from its static copy
    after the entry has been set explicitly on the simulation.
    """

    def assert_payload_stored(hdf5_path, payload):
        """Open ``hdf5_path`` read-only and compare the stored attribute to ``payload``."""
        with h5py.File(hdf5_path, "r") as handle:
            assert TRACED_FIELD_KEYS_ATTR in handle.attrs
            assert handle.attrs[TRACED_FIELD_KEYS_ATTR] == payload

    sim = SIM_FULL.copy()
    payload = sim._serialized_traced_field_keys()
    assert payload, "simulation fixture must yield traced keys"

    # Export with no traced-keys entry present in ``attrs``.
    sim.attrs.pop(TRACED_FIELD_KEYS_ATTR, None)
    traced_file = tmp_path / "sim_traced.hdf5"
    sim.to_hdf5(str(traced_file))
    assert_payload_stored(traced_file, payload)

    # Export the static copy made after the entry was written into ``attrs``.
    static_file = tmp_path / "sim_traced_static.hdf5"
    sim.attrs[TRACED_FIELD_KEYS_ATTR] = payload
    sim.to_static().to_hdf5(str(static_file))
    assert_payload_stored(static_file, payload)
1219+
1220+
1221+
def test_web_run_duplicate_simulations(monkeypatch):
    """Check ``web.run`` on a list holding the same simulation object twice.

    ``run_autograd`` is replaced by a stub that always returns one shared dummy
    object. The first result must be that object itself, the second must be a
    different object, and exactly one ``copy()`` call must have happened.
    """
    sim = SIM_FULL.copy()
    sim.attrs.pop(TRACED_FIELD_KEYS_ATTR, None)

    # Every object produced by ``DummyData.copy``, in creation order.
    copies = []

    class DummyData:
        """Minimal stand-in for simulation data that records its copies."""

        def __init__(self, label: str):
            self.label = label

        def copy(self):
            duplicate = DummyData(f"{self.label}_copy{len(copies) + 1}")
            copies.append(duplicate)
            return duplicate

    dummy = DummyData("root")
    monkeypatch.setattr("tidy3d.web.api.run.run_autograd", lambda *args, **kwargs: dummy)

    results = web.run([sim, sim])

    assert isinstance(results, list)
    assert len(results) == 2
    assert results[0] is dummy
    assert results[1] is not dummy
    assert len(copies) == 1
1251+
1252+
1253+
def test_autograd_run_does_not_mutate_input_attrs(monkeypatch):
    """Check that ``_run`` puts traced-key metadata on the static copy only.

    ``_run_primitive`` and ``postprocess_run`` are replaced by recording stubs.
    After the call, the input simulation's ``attrs`` must still lack the
    traced-keys entry, while the ``sim_original`` handed to the primitive must
    be a different object that carries the expected payload.
    """
    sim = SIM_FULL.copy()
    sim.attrs.pop(TRACED_FIELD_KEYS_ATTR, None)
    payload = sim._serialized_traced_field_keys()
    assert payload

    captured: dict[str, typing.Any] = {}

    # Parameter names mirror the keywords used by the patched call site.
    def fake_run_primitive(
        sim_fields,
        sim_original,
        task_name,
        aux_data,
        local_gradient,
        max_num_adjoint_per_fwd,
        **run_kwargs,
    ):
        captured.update(
            sim_original=sim_original,
            # Read at call time so the value reflects what the primitive received.
            payload=sim_original.attrs.get(TRACED_FIELD_KEYS_ATTR),
            sim_fields=sim_fields,
            aux_data=aux_data,
        )
        return sim_fields

    def fake_postprocess_run(traced_fields_data, aux_data):
        captured.update(postprocess_data=traced_fields_data, postprocess_aux=aux_data)
        return "sentinel"

    monkeypatch.setattr(autograd_module, "_run_primitive", fake_run_primitive)
    monkeypatch.setattr(autograd_module, "postprocess_run", fake_postprocess_run)

    result = autograd_module._run(simulation=sim, task_name="dummy")
    assert result == "sentinel"

    # The caller's simulation is left without the metadata entry.
    assert sim.attrs.get(TRACED_FIELD_KEYS_ATTR) is None

    # The object passed to the primitive is distinct and carries the payload.
    received_sim = captured["sim_original"]
    assert captured["payload"] == payload
    assert received_sim is not sim
    assert received_sim.attrs.get(TRACED_FIELD_KEYS_ATTR) == payload

    # The primitive's output and the aux dict flow unchanged into postprocessing.
    assert captured["postprocess_data"] == captured["sim_fields"]
    assert captured["postprocess_aux"] is captured["aux_data"]
    assert captured["postprocess_aux"] == {}
1296+
1297+
11771298
def test_sim_traced_override_structures():
11781299
"""Make sure that sims with traced override structures are handled properly."""
11791300

tidy3d/components/autograd/field_map.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Typed containers for autograd traced field metadata."""
2+
3+
from __future__ import annotations
4+
5+
import json
6+
from typing import Any, Callable
7+
8+
import pydantic.v1 as pydantic
9+
10+
from tidy3d.components.autograd.types import AutogradFieldMap, dict_ag
11+
from tidy3d.components.base import Tidy3dBaseModel
12+
from tidy3d.components.types import ArrayLike, tidycomplex
13+
14+
15+
class Tracer(Tidy3dBaseModel):
    """A single traced element of a model: its path plus the data stored there."""

    # Keys/indices that locate the traced object inside the model dictionary.
    path: tuple[Any, ...] = pydantic.Field(
        ..., title="Path to the traced object in the model dictionary."
    )

    # Traced value: a real scalar, a complex scalar, or an array-like.
    data: float | tidycomplex | ArrayLike = pydantic.Field(
        ...,
        title="Tracing data",
    )
23+
24+
25+
class FieldMap(Tidy3dBaseModel):
    """Serializable collection of ``Tracer`` objects.

    Converts to and from an ``AutogradFieldMap`` dictionary keyed by tracer path.
    """

    tracers: tuple[Tracer, ...] = pydantic.Field(..., title="Collection of Tracers.")

    @property
    def to_autograd_field_map(self) -> AutogradFieldMap:
        """Autograd dictionary mapping each tracer's ``path`` to its ``data``.

        Note that this is a property, so it is accessed without parentheses.
        """
        path_to_data = {}
        for tracer in self.tracers:
            path_to_data[tracer.path] = tracer.data
        return dict_ag(path_to_data)

    @classmethod
    def from_autograd_field_map(cls, autograd_field_map: AutogradFieldMap) -> FieldMap:
        """Build a ``FieldMap`` with one ``Tracer`` per item of ``autograd_field_map``."""
        return cls(
            tracers=tuple(
                Tracer(path=path, data=data) for path, data in autograd_field_map.items()
            )
        )
45+
46+
47+
def _encoded_path(path: tuple[Any, ...]) -> str:
48+
"""Return a stable JSON representation for a traced path."""
49+
return json.dumps(list(path), separators=(",", ":"), ensure_ascii=True)
50+
51+
52+
class TracerKeys(Tidy3dBaseModel):
    """Serializable collection of traced field paths (keys only, no data)."""

    keys: tuple[tuple[Any, ...], ...] = pydantic.Field(..., title="Collection of tracer keys.")

    def encoded_keys(self) -> list[str]:
        """Return each stored path as its JSON-encoded string, in stored order."""
        return list(map(_encoded_path, self.keys))

    @classmethod
    def from_field_mapping(
        cls,
        field_mapping: AutogradFieldMap,
        *,
        sort_key: Callable[[tuple[Any, ...]], str] | None = None,
    ) -> TracerKeys:
        """Build ``TracerKeys`` from the keys of ``field_mapping``, sorted.

        Parameters
        ----------
        field_mapping : AutogradFieldMap
            Mapping whose keys are the traced paths to record.
        sort_key : Callable[[tuple[Any, ...]], str] | None = None
            Function used as the sorting key for the paths. When ``None``, the
            JSON encoding of each path is used.
        """
        ordering = _encoded_path if sort_key is None else sort_key
        ordered_paths = sorted(field_mapping.keys(), key=ordering)
        return cls(keys=tuple(ordered_paths))

tidy3d/components/base.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
# If json string is larger than ``MAX_STRING_LENGTH``, split the string when storing in hdf5
4040
MAX_STRING_LENGTH = 1_000_000_000
4141
FORBID_SPECIAL_CHARACTERS = ["/"]
42+
TRACED_FIELD_KEYS_ATTR = "__tidy3d_traced_field_keys__"
4243

4344

4445
def cache(prop):
@@ -524,7 +525,8 @@ def to_json(self, fname: str) -> None:
524525
-------
525526
>>> simulation.to_json(fname='folder/sim.json') # doctest: +SKIP
526527
"""
527-
json_string = self._json(indent=INDENT_JSON_FILE)
528+
export_model = self.to_static()
529+
json_string = export_model._json(indent=INDENT_JSON_FILE)
528530
self._warn_if_contains_data(json_string)
529531
with open(fname, "w", encoding="utf-8") as file_handle:
530532
file_handle.write(json_string)
@@ -586,7 +588,8 @@ def to_yaml(self, fname: str) -> None:
586588
-------
587589
>>> simulation.to_yaml(fname='folder/sim.yaml') # doctest: +SKIP
588590
"""
589-
json_string = self._json_string
591+
export_model = self.to_static()
592+
json_string = export_model._json()
590593
self._warn_if_contains_data(json_string)
591594
model_dict = json.loads(json_string)
592595
with open(fname, "w+", encoding="utf-8") as file_handle:
@@ -792,8 +795,15 @@ def to_hdf5(
792795
>>> simulation.to_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP
793796
"""
794797

798+
export_model = self.to_static()
799+
traced_keys_payload = export_model.attrs.get(TRACED_FIELD_KEYS_ATTR)
800+
801+
if traced_keys_payload is None:
802+
traced_keys_payload = self.attrs.get(TRACED_FIELD_KEYS_ATTR)
803+
if traced_keys_payload is None:
804+
traced_keys_payload = self._serialized_traced_field_keys()
795805
with h5py.File(fname, "w") as f_handle:
796-
json_str = self._json_string
806+
json_str = export_model._json()
797807
for ind in range(ceil(len(json_str) / MAX_STRING_LENGTH)):
798808
ind_start = int(ind * MAX_STRING_LENGTH)
799809
ind_stop = min(int(ind + 1) * MAX_STRING_LENGTH, len(json_str))
@@ -816,14 +826,16 @@ def add_data_to_file(data_dict: dict, group_path: str = "") -> None:
816826

817827
# if a tuple, assign each element a unique key
818828
if isinstance(value, (list, tuple)):
819-
value_dict = self.tuple_to_dict(tuple_values=value)
829+
value_dict = export_model.tuple_to_dict(tuple_values=value)
820830
add_data_to_file(data_dict=value_dict, group_path=subpath)
821831

822832
# if a dict, recurse
823833
elif isinstance(value, dict):
824834
add_data_to_file(data_dict=value, group_path=subpath)
825835

826-
add_data_to_file(data_dict=self.dict())
836+
add_data_to_file(data_dict=export_model.dict())
837+
if traced_keys_payload:
838+
f_handle.attrs[TRACED_FIELD_KEYS_ATTR] = traced_keys_payload
827839

828840
@classmethod
829841
def dict_from_hdf5_gz(
@@ -1101,6 +1113,22 @@ def insert_value(x, path: tuple[str, ...], sub_dict: dict):
11011113

11021114
return self.parse_obj(self_dict)
11031115

1116+
def _serialized_traced_field_keys(
    self, field_mapping: AutogradFieldMap | None = None
) -> Optional[str]:
    """Serialize the traced field paths into a compact JSON string.

    Parameters
    ----------
    field_mapping : AutogradFieldMap | None = None
        Mapping whose keys are the traced paths. When ``None``, the mapping is
        obtained from ``self._strip_traced_fields()``.

    Returns
    -------
    Optional[str]
        JSON of a ``TracerKeys`` built from the sorted paths, or ``None`` when
        the mapping is empty.
    """
    mapping = self._strip_traced_fields() if field_mapping is None else field_mapping
    if not mapping:
        return None

    # TODO: remove this deferred import once TracerKeys is decoupled from Tidy3dBaseModel.
    from tidy3d.components.autograd.field_map import TracerKeys

    sorted_keys = TracerKeys.from_field_mapping(mapping)
    return sorted_keys.json(separators=(",", ":"), ensure_ascii=True)
1131+
11041132
def to_static(self) -> Tidy3dBaseModel:
11051133
"""Version of object with all autograd-traced fields removed."""
11061134

tidy3d/web/api/autograd/autograd.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
MAX_NUM_ADJOINT_PER_FWD,
1515
MAX_NUM_TRACED_STRUCTURES,
1616
)
17+
from tidy3d.components.base import TRACED_FIELD_KEYS_ATTR
1718
from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType
1819
from tidy3d.exceptions import AdjointError
1920
from tidy3d.web.api.asynchronous import DEFAULT_DATA_DIR
@@ -409,9 +410,14 @@ def _run(
409410
aux_data = {}
410411

411412
# run our custom @primitive, passing the traced fields first to register with autograd
413+
sim_static = simulation.to_static()
414+
traced_keys_payload = simulation._serialized_traced_field_keys()
415+
if traced_keys_payload:
416+
sim_static.attrs[TRACED_FIELD_KEYS_ATTR] = traced_keys_payload
417+
412418
traced_fields_data = _run_primitive(
413419
traced_fields_sim, # if you pass as a kwarg it will not trace :/
414-
sim_original=simulation.to_static(),
420+
sim_original=sim_static,
415421
task_name=task_name,
416422
aux_data=aux_data,
417423
local_gradient=local_gradient,
@@ -433,16 +439,22 @@ def _run_async(
433439
task_names = simulations.keys()
434440

435441
traced_fields_sim_dict = {}
442+
sims_original = {}
436443
for task_name in task_names:
437-
traced_fields_sim_dict[task_name] = setup_run(simulation=simulations[task_name])
444+
simulation = simulations[task_name]
445+
traced_fields = setup_run(simulation=simulation)
446+
traced_fields_sim_dict[task_name] = traced_fields
447+
sim_static = simulation.to_static()
448+
if traced_fields:
449+
traced_keys_payload = simulation._serialized_traced_field_keys()
450+
if traced_keys_payload:
451+
sim_static.attrs[TRACED_FIELD_KEYS_ATTR] = traced_keys_payload
452+
sims_original[task_name] = sim_static
438453
traced_fields_sim_dict = dict_ag(traced_fields_sim_dict)
439454

440455
# TODO: shortcut primitive running for any items with no tracers?
441456

442457
aux_data_dict = {task_name: {} for task_name in task_names}
443-
sims_original = {
444-
task_name: simulation.to_static() for task_name, simulation in simulations.items()
445-
}
446458
traced_fields_data_dict = _run_async_primitive(
447459
traced_fields_sim_dict, # if you pass as a kwarg it will not trace :/
448460
sims_original=sims_original,

tidy3d/web/api/autograd/io_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import tempfile
55

66
import tidy3d as td
7+
from tidy3d.components.autograd.field_map import FieldMap, TracerKeys
78
from tidy3d.web.core.s3utils import download_file, upload_file # type: ignore
89

910
from .constants import SIM_FIELDS_KEYS_FILE, SIM_VJP_FILE
10-
from .utils import FieldMap, TracerKeys
1111

1212

1313
def upload_sim_fields_keys(sim_fields_keys: list[tuple], task_id: str, verbose: bool = False):

0 commit comments

Comments
 (0)