Skip to content

Commit 88a8d75

Browse files
fix(tidy3d): FXC-3948-fix-caching-for-autograd
1 parent 8da1ce8 commit 88a8d75

File tree

7 files changed

+188
-71
lines changed

7 files changed

+188
-71
lines changed

tests/test_web/test_local_cache.py

Lines changed: 100 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,23 @@
77
from pathlib import Path
88
from types import SimpleNamespace
99

10+
import autograd as ag
1011
import pytest
12+
import xarray as xr
13+
from autograd.core import defvjp
1114
from rich.console import Console
1215

1316
import tidy3d as td
1417
from tests.test_components.autograd.test_autograd import ALL_KEY, get_functions, params0
1518
from tests.test_web.test_webapi_mode import make_mode_sim
1619
from tidy3d import config
20+
from tidy3d.components.autograd.field_map import FieldMap
1721
from tidy3d.config import get_manager
1822
from tidy3d.web import Job, common, run, run_async
1923
from tidy3d.web.api import webapi as web
24+
from tidy3d.web.api.autograd import autograd, engine, io_utils
25+
from tidy3d.web.api.autograd.autograd import run as run_autograd
26+
from tidy3d.web.api.autograd.constants import SIM_VJP_FILE
2027
from tidy3d.web.api.container import Batch, WebContainer
2128
from tidy3d.web.api.webapi import load_simulation_if_cached
2229
from tidy3d.web.cache import CACHE_ARTIFACT_NAME, clear, get_cache_entry_dir, resolve_local_cache
@@ -39,6 +46,24 @@ class _FakeStubData:
3946
def __init__(self, simulation: td.Simulation):
4047
self.simulation = simulation
4148

49+
def __getitem__(self, key):
50+
if key == "mode":
51+
params = self.simulation.attrs["params_autograd"]
52+
return SimpleNamespace(
53+
amps=xr.DataArray(params, dims=["x"], coords={"x": list(range(len(params)))})
54+
)
55+
56+
def _strip_traced_fields(self, *args, **kwargs):
57+
"""Fake _strip_traced_fields: return minimal valid autograd-style mapping."""
58+
return {"params": self.simulation.attrs["params"]}
59+
60+
def _insert_traced_fields(self, field_mapping, *args, **kwargs):
61+
self.simulation.attrs["params_autograd"] = field_mapping["params"]
62+
return self
63+
64+
def _make_adjoint_sims(self, **kwargs):
65+
return [self.simulation.updated_copy(run_time=self.simulation.run_time * 2)]
66+
4267

4368
@pytest.fixture
4469
def basic_simulation():
@@ -128,25 +153,50 @@ def _fake__check_folder(*args, **kwargs):
128153
def _fake_status(self):
129154
return "success"
130155

156+
def _fake_download_file(resource_id, remote_filename, to_file=None, **kwargs):
    """Fake ``download_file``: count the request without transferring anything.

    Only requests whose remote filename equals ``SIM_VJP_FILE`` (the
    adjoint/VJP file) bump ``counters["download"]``; every other filename is
    ignored.
    """
    is_vjp_request = str(remote_filename) == SIM_VJP_FILE
    if is_vjp_request:
        counters["download"] += 1
160+
161+
def _fake_from_file(*args, **kwargs):
    """Return an empty ``FieldMap`` (no tracers), ignoring every argument."""
    return FieldMap(tracers=())
164+
165+
monkeypatch.setattr(io_utils, "download_file", _fake_download_file)
166+
monkeypatch.setattr(autograd, "postprocess_fwd", _fake_from_file)
167+
monkeypatch.setattr(FieldMap, "from_file", _fake_from_file)
131168
monkeypatch.setattr(WebContainer, "_check_folder", _fake__check_folder)
132169
monkeypatch.setattr(web, "upload", _fake_upload)
133170
monkeypatch.setattr(web, "start", _fake_start)
134171
monkeypatch.setattr(web, "monitor", _fake_monitor)
135172
monkeypatch.setattr(web, "download", _fake_download)
136173
monkeypatch.setattr(web, "estimate_cost", lambda *args, **kwargs: 0.0)
137174
monkeypatch.setattr(Job, "status", property(_fake_status))
175+
monkeypatch.setattr(engine, "upload_sim_fields_keys", lambda *args, **kwargs: None)
138176
monkeypatch.setattr(
139177
web,
140178
"get_info",
141179
lambda task_id, verbose=True: type(
142180
"_Info", (), {"solverVersion": "solver-1", "taskType": "FDTD"}
143181
)(),
144182
)
183+
monkeypatch.setattr(
184+
io_utils,
185+
"get_info",
186+
lambda task_id, verbose=True: type(
187+
"_Info", (), {"solverVersion": "solver-1", "taskType": "FDTD"}
188+
)(),
189+
)
145190
monkeypatch.setattr(
146191
web, "load_simulation", lambda task_id, *args, **kwargs: TASK_TO_SIM[task_id]
147192
)
193+
monkeypatch.setattr(
194+
io_utils, "load_simulation", lambda task_id, *args, **kwargs: TASK_TO_SIM[task_id]
195+
)
148196
monkeypatch.setattr(BatchTask, "is_batch", lambda *args, **kwargs: "success")
149-
monkeypatch.setattr(BatchTask, "detail", lambda *args: SimpleNamespace(status="success"))
197+
monkeypatch.setattr(
198+
BatchTask, "detail", lambda *args, **kwargs: SimpleNamespace(status="success")
199+
)
150200
return counters
151201

152202

@@ -371,23 +421,59 @@ def _test_job_run_cache(monkeypatch, basic_simulation, tmp_path):
371421
assert os.path.exists(out2_path)
372422

373423

374-
def _test_autograd_cache(monkeypatch):
424+
def _test_autograd_cache(monkeypatch, request):
    """Check that an autograd run populates and then reuses the local cache.

    The objective is differentiated twice with the faked run pipeline. The
    first pass must record two downloads and leave two cache entries; the
    second pass must record one download and leave the cache at two entries.

    While the test body runs, the VJP of ``ag.builtins._make_dict`` is replaced
    by one that returns zeros. That override is global autograd state, so it is
    undone in a ``finally`` clause before this helper returns and cannot leak
    into the sub-tests that run after it.

    Parameters
    ----------
    monkeypatch
        pytest ``monkeypatch`` fixture, forwarded to ``_patch_run_pipeline``.
    request
        pytest ``request`` fixture. Kept for caller compatibility; restoration
        no longer relies on ``request.addfinalizer``.
    """
    counters = _patch_run_pipeline(monkeypatch)

    # "Original" rule: the one autograd uses by default
    def _orig_make_dict_vjp(ans, keys, vals):
        return lambda g: [g[key] for key in keys]

    def _zero_make_dict_vjp(ans, keys, vals):
        def vjp(g):
            # One gradient per entry in `vals`, all zeros, matching shape/dtype
            return [ag.numpy.zeros_like(v) for v in vals]

        return vjp

    # Install our zero-VJP (this is the thing that affects global state)
    defvjp(
        ag.builtins._make_dict,
        _zero_make_dict_vjp,
        argnums=(1,),  # gradient w.r.t. `vals`
    )

    try:
        cache = resolve_local_cache(use_cache=True)
        cache.clear()

        functions = get_functions(ALL_KEY, "mode")
        make_sim = functions["sim"]
        postprocess = functions["postprocess"]

        def objective(params):
            sim = make_sim(params)
            sim.attrs["params"] = params
            sim_data = run_autograd(sim)
            value = postprocess(sim_data)
            return value

        ag.value_and_grad(objective)(params0)
        assert counters["download"] == 2
        assert len(cache) == 2

        _reset_counters(counters)
        ag.value_and_grad(objective)(params0)
        assert counters["download"] == 1  # download field data
        assert len(cache) == 2
    finally:
        # Restore the default rule immediately, even when an assertion fails.
        defvjp(
            ag.builtins._make_dict,
            _orig_make_dict_vjp,
            argnums=(1,),
        )
391477

392478

393479
def _test_load_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data):
@@ -499,7 +585,9 @@ def _test_env_var_overrides(monkeypatch, tmp_path):
499585
manager._reload()
500586

501587

502-
def test_cache_sequential(monkeypatch, tmp_path, tmp_path_factory, basic_simulation, fake_data):
588+
def test_cache_sequential(
589+
monkeypatch, tmp_path, tmp_path_factory, basic_simulation, fake_data, request
590+
):
503591
"""Run all critical cache tests in sequence to ensure stability."""
504592
monkeypatch.setattr(config.local_cache, "enabled", True)
505593

@@ -514,7 +602,7 @@ def test_cache_sequential(monkeypatch, tmp_path, tmp_path_factory, basic_simulat
514602
_test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation)
515603
_test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path)
516604
_test_job_run_cache(monkeypatch, basic_simulation, tmp_path)
517-
_test_autograd_cache(monkeypatch)
605+
_test_autograd_cache(monkeypatch, request)
518606
_test_configure_cache_roundtrip(monkeypatch, tmp_path)
519607
_test_mode_solver_caching(monkeypatch, tmp_path)
520608
_test_verbosity(monkeypatch, basic_simulation)

tests/test_web/test_webapi.py

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,29 @@
6969
Env.dev.active()
7070

7171

72+
class FakeJob:
    """Scripted stand-in for a job object used by the batch-monitor tests.

    Each read of ``status`` and each call to ``download`` appends a
    ``(task_id, kind, value)`` tuple to the caller-owned ``events`` list, so a
    test can assert on the exact order of operations.
    """

    def __init__(
        self, task_id: str, statuses: list[str], events: list[tuple[str, str, str]]
    ) -> None:
        """Store the scripted statuses and the shared event log.

        Parameters
        ----------
        task_id
            Identifier recorded as the first element of every event.
        statuses
            Statuses to replay in order; the last one repeats indefinitely.
        events
            List owned by the test. It is appended to in place, never copied.
        """
        self.task_id = task_id
        self._statuses = statuses
        self._idx = 0
        self.events = events

    @property
    def status(self) -> str:
        """Return the current scripted status, log it, and advance the script.

        The index stops at the final entry, so reads past the end of the
        script keep returning the last status.
        """
        status = self._statuses[self._idx]
        if self._idx < len(self._statuses) - 1:
            self._idx += 1
        self.events.append((self.task_id, "status", status))
        return status

    def download(self, path: PathLike) -> None:
        """Log a download request for ``path``; nothing is written to disk."""
        self.events.append((self.task_id, "download", str(path)))

    @property
    def load_if_cached(self) -> bool:
        """Always ``False``.

        NOTE(review): presumably signals that nothing was restored from the
        local cache; confirm against the real job's ``load_if_cached``.
        """
        return False
93+
94+
7295
class ImmediateExecutor:
7396
def __init__(self, *args, **kwargs):
7497
pass
@@ -728,32 +751,15 @@ def mock_start_interrupt(self, *args, **kwargs):
728751
def test_batch_monitor_downloads_on_success(monkeypatch, tmp_path):
729752
events = []
730753

731-
class FakeJob:
732-
def __init__(self, task_id: str, statuses: list[str]):
733-
self.task_id = task_id
734-
self._statuses = statuses
735-
self._idx = 0
736-
737-
@property
738-
def status(self):
739-
status = self._statuses[self._idx]
740-
if self._idx < len(self._statuses) - 1:
741-
self._idx += 1
742-
events.append((self.task_id, "status", status))
743-
return status
744-
745-
def download(self, path: PathLike):
746-
events.append((self.task_id, "download", str(path)))
747-
748754
monkeypatch.setattr("tidy3d.web.api.container.ThreadPoolExecutor", ImmediateExecutor)
749755
monkeypatch.setattr("tidy3d.web.api.container.time.sleep", lambda *_args, **_kwargs: None)
750756

751757
sims = {"task_a": make_sim(), "task_b": make_sim()}
752758
batch = Batch(simulations=sims, folder_name=PROJECT_NAME, verbose=False)
753759
batch._cached_properties = {}
754760
fake_jobs = {
755-
"task_a": FakeJob("task_a_id", ["running", "success", "success"]),
756-
"task_b": FakeJob("task_b_id", ["running", "running", "success"]),
761+
"task_a": FakeJob("task_a_id", ["running", "success", "success"], events),
762+
"task_b": FakeJob("task_b_id", ["running", "running", "success"], events),
757763
}
758764
batch._cached_properties["jobs"] = fake_jobs
759765

@@ -786,32 +792,15 @@ def download(self, path: PathLike):
786792
def test_batch_monitor_skips_existing_download(monkeypatch, tmp_path):
787793
events = []
788794

789-
class FakeJob:
790-
def __init__(self, task_id: str, statuses: list[str]):
791-
self.task_id = task_id
792-
self._statuses = statuses
793-
self._idx = 0
794-
795-
@property
796-
def status(self):
797-
status = self._statuses[self._idx]
798-
if self._idx < len(self._statuses) - 1:
799-
self._idx += 1
800-
events.append((self.task_id, "status", status))
801-
return status
802-
803-
def download(self, path: PathLike):
804-
events.append((self.task_id, "download", str(path)))
805-
806795
monkeypatch.setattr("tidy3d.web.api.container.ThreadPoolExecutor", ImmediateExecutor)
807796
monkeypatch.setattr("tidy3d.web.api.container.time.sleep", lambda *_args, **_kwargs: None)
808797

809798
sims = {"task_a": make_sim(), "task_b": make_sim()}
810799
batch = Batch(simulations=sims, folder_name=PROJECT_NAME, verbose=False)
811800
batch._cached_properties = {}
812801
fake_jobs = {
813-
"task_a": FakeJob("task_a_id", ["success", "success"]),
814-
"task_b": FakeJob("task_b_id", ["running", "success"]),
802+
"task_a": FakeJob("task_a_id", ["success", "success"], events),
803+
"task_b": FakeJob("task_b_id", ["running", "success"], events),
815804
}
816805
batch._cached_properties["jobs"] = fake_jobs
817806

tidy3d/web/api/autograd/autograd.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from tidy3d.web.api.asynchronous import run_async as run_async_webapi
2020
from tidy3d.web.api.container import BatchData
2121
from tidy3d.web.api.tidy3d_stub import Tidy3dStub
22+
from tidy3d.web.api.webapi import load, restore_simulation_if_cached
2223
from tidy3d.web.api.webapi import run as run_webapi
2324
from tidy3d.web.core.types import PayType
2425

@@ -561,16 +562,31 @@ def _run_primitive(
561562
aux_data=aux_data,
562563
)
563564
else:
564-
sim_combined.validate_pre_upload()
565565
sim_original = sim_original.updated_copy(simulation_type="autograd_fwd", deep=False)
566-
run_kwargs["simulation_type"] = "autograd_fwd"
567-
run_kwargs["sim_fields_keys"] = list(sim_fields.keys())
568-
569-
sim_data_orig, task_id_fwd = _run_tidy3d(
570-
sim_original,
571-
task_name=task_name,
572-
**run_kwargs,
566+
restored_path, task_id_fwd = restore_simulation_if_cached(
567+
simulation=sim_original,
568+
path=run_kwargs.get("path", None),
569+
reduce_simulation=run_kwargs.get("reduce_simulation", "auto"),
570+
verbose=run_kwargs.get("verbose", True),
573571
)
572+
if restored_path is None or task_id_fwd is None:
573+
sim_combined.validate_pre_upload()
574+
run_kwargs["simulation_type"] = "autograd_fwd"
575+
run_kwargs["sim_fields_keys"] = list(sim_fields.keys())
576+
577+
sim_data_orig, task_id_fwd = _run_tidy3d(
578+
sim_original,
579+
task_name=task_name,
580+
**run_kwargs,
581+
)
582+
else:
583+
sim_data_orig = load(
584+
task_id=None,
585+
path=run_kwargs.get("path", None),
586+
verbose=run_kwargs.get("verbose", None),
587+
progress_callback=run_kwargs.get("progress_callback", None),
588+
lazy=run_kwargs.get("lazy", None),
589+
)
574590

575591
# TODO: put this in postprocess?
576592
aux_data[AUX_KEY_FWD_TASK_ID] = task_id_fwd

tidy3d/web/api/autograd/io_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import tidy3d as td
77
from tidy3d.components.autograd import AutogradFieldMap
88
from tidy3d.components.autograd.field_map import FieldMap, TracerKeys
9+
from tidy3d.web.api.webapi import get_info, load_simulation
10+
from tidy3d.web.cache import resolve_local_cache
911
from tidy3d.web.core.s3utils import download_file, upload_file # type: ignore
1012

1113
from .constants import SIM_FIELDS_KEYS_FILE, SIM_VJP_FILE
@@ -39,6 +41,21 @@ def get_vjp_traced_fields(task_id_adj: str, verbose: bool) -> AutogradFieldMap:
3941
try:
4042
download_file(task_id_adj, SIM_VJP_FILE, to_file=fname, verbose=verbose)
4143
field_map = FieldMap.from_file(fname)
44+
45+
simulation_cache = resolve_local_cache()
46+
if simulation_cache is not None:
47+
info = get_info(task_id_adj, verbose=False)
48+
workflow_type = getattr(info, "taskType", None)
49+
simulation = None
50+
with tempfile.NamedTemporaryFile(suffix=".hdf5") as tmp_file:
51+
simulation = load_simulation(task_id_adj, path=tmp_file.name, verbose=False)
52+
simulation_cache.store_result(
53+
stub_data=field_map,
54+
task_id=task_id_adj,
55+
path=fname,
56+
workflow_type=workflow_type,
57+
simulation=simulation,
58+
)
4259
except Exception as e:
4360
td.log.error(f"Error occurred while getting VJP traced fields: {e}")
4461
raise e

0 commit comments

Comments
 (0)