77from pathlib import Path
88from types import SimpleNamespace
99
10+ import autograd as ag
1011import pytest
12+ import xarray as xr
13+ from autograd .core import defvjp
1114from rich .console import Console
1215
1316import tidy3d as td
1417from tests .test_components .autograd .test_autograd import ALL_KEY , get_functions , params0
1518from tests .test_web .test_webapi_mode import make_mode_sim
1619from tidy3d import config
20+ from tidy3d .components .autograd .field_map import FieldMap
1721from tidy3d .config import get_manager
1822from tidy3d .web import Job , common , run , run_async
1923from tidy3d .web .api import webapi as web
24+ from tidy3d .web .api .autograd import autograd , engine , io_utils
25+ from tidy3d .web .api .autograd .autograd import run as run_autograd
26+ from tidy3d .web .api .autograd .constants import SIM_VJP_FILE
2027from tidy3d .web .api .container import Batch , WebContainer
2128from tidy3d .web .api .webapi import load_simulation_if_cached
2229from tidy3d .web .cache import CACHE_ARTIFACT_NAME , clear , get_cache_entry_dir , resolve_local_cache
@@ -39,6 +46,24 @@ class _FakeStubData:
3946 def __init__ (self , simulation : td .Simulation ):
4047 self .simulation = simulation
4148
49+ def __getitem__ (self , key ):
50+ if key == "mode" :
51+ params = self .simulation .attrs ["params_autograd" ]
52+ return SimpleNamespace (
53+ amps = xr .DataArray (params , dims = ["x" ], coords = {"x" : list (range (len (params )))})
54+ )
55+
56+ def _strip_traced_fields (self , * args , ** kwargs ):
57+ """Fake _strip_traced_fields: return minimal valid autograd-style mapping."""
58+ return {"params" : self .simulation .attrs ["params" ]}
59+
60+ def _insert_traced_fields (self , field_mapping , * args , ** kwargs ):
61+ self .simulation .attrs ["params_autograd" ] = field_mapping ["params" ]
62+ return self
63+
64+ def _make_adjoint_sims (self , ** kwargs ):
65+ return [self .simulation .updated_copy (run_time = self .simulation .run_time * 2 )]
66+
4267
4368@pytest .fixture
4469def basic_simulation ():
@@ -128,25 +153,50 @@ def _fake__check_folder(*args, **kwargs):
128153 def _fake_status (self ):
129154 return "success"
130155
156+ def _fake_download_file (resource_id , remote_filename , to_file = None , ** kwargs ):
157+ # Only count this download if it's the adjoint/VJP file
158+ if str (remote_filename ) == SIM_VJP_FILE :
159+ counters ["download" ] += 1
160+
161+ def _fake_from_file (* args , ** kwargs ):
162+ field_map = FieldMap (tracers = ())
163+ return field_map
164+
165+ monkeypatch .setattr (io_utils , "download_file" , _fake_download_file )
166+ monkeypatch .setattr (autograd , "postprocess_fwd" , _fake_from_file )
167+ monkeypatch .setattr (FieldMap , "from_file" , _fake_from_file )
131168 monkeypatch .setattr (WebContainer , "_check_folder" , _fake__check_folder )
132169 monkeypatch .setattr (web , "upload" , _fake_upload )
133170 monkeypatch .setattr (web , "start" , _fake_start )
134171 monkeypatch .setattr (web , "monitor" , _fake_monitor )
135172 monkeypatch .setattr (web , "download" , _fake_download )
136173 monkeypatch .setattr (web , "estimate_cost" , lambda * args , ** kwargs : 0.0 )
137174 monkeypatch .setattr (Job , "status" , property (_fake_status ))
175+ monkeypatch .setattr (engine , "upload_sim_fields_keys" , lambda * args , ** kwargs : None )
138176 monkeypatch .setattr (
139177 web ,
140178 "get_info" ,
141179 lambda task_id , verbose = True : type (
142180 "_Info" , (), {"solverVersion" : "solver-1" , "taskType" : "FDTD" }
143181 )(),
144182 )
183+ monkeypatch .setattr (
184+ io_utils ,
185+ "get_info" ,
186+ lambda task_id , verbose = True : type (
187+ "_Info" , (), {"solverVersion" : "solver-1" , "taskType" : "FDTD" }
188+ )(),
189+ )
145190 monkeypatch .setattr (
146191 web , "load_simulation" , lambda task_id , * args , ** kwargs : TASK_TO_SIM [task_id ]
147192 )
193+ monkeypatch .setattr (
194+ io_utils , "load_simulation" , lambda task_id , * args , ** kwargs : TASK_TO_SIM [task_id ]
195+ )
148196 monkeypatch .setattr (BatchTask , "is_batch" , lambda * args , ** kwargs : "success" )
149- monkeypatch .setattr (BatchTask , "detail" , lambda * args : SimpleNamespace (status = "success" ))
197+ monkeypatch .setattr (
198+ BatchTask , "detail" , lambda * args , ** kwargs : SimpleNamespace (status = "success" )
199+ )
150200 return counters
151201
152202
@@ -371,23 +421,59 @@ def _test_job_run_cache(monkeypatch, basic_simulation, tmp_path):
371421 assert os .path .exists (out2_path )
372422
373423
374- def _test_autograd_cache (monkeypatch ):
424+ def _test_autograd_cache (monkeypatch , request ):
375425 counters = _patch_run_pipeline (monkeypatch )
426+
427+ # "Original" rule: the one autograd uses by default
428+ def _orig_make_dict_vjp (ans , keys , vals ):
429+ return lambda g : [g [key ] for key in keys ]
430+
431+ def _zero_make_dict_vjp (ans , keys , vals ):
432+ def vjp (g ):
433+ # One gradient per entry in `vals`, all zeros, matching shape/dtype
434+ return [ag .numpy .zeros_like (v ) for v in vals ]
435+
436+ return vjp
437+
438+ # Install our zero-VJP (this is the thing that affects global state)
439+ defvjp (
440+ ag .builtins ._make_dict ,
441+ _zero_make_dict_vjp ,
442+ argnums = (1 ,), # gradient w.r.t. `vals`
443+ )
444+
445+ # Make sure we restore it after the test
446+ def _restore_make_dict_vjp ():
447+ defvjp (
448+ ag .builtins ._make_dict ,
449+ _orig_make_dict_vjp ,
450+ argnums = (1 ,),
451+ )
452+
453+ request .addfinalizer (_restore_make_dict_vjp )
454+
376455 cache = resolve_local_cache (use_cache = True )
377456 cache .clear ()
378457
379458 functions = get_functions (ALL_KEY , "mode" )
380459 make_sim = functions ["sim" ]
381- sim = make_sim (params0 )
382- web .run (sim )
383- assert counters ["download" ] == 1
384- assert len (cache ) == 1
460+ postprocess = functions ["postprocess" ]
461+
462+ def objective (params ):
463+ sim = make_sim (params )
464+ sim .attrs ["params" ] = params
465+ sim_data = run_autograd (sim )
466+ value = postprocess (sim_data )
467+ return value
468+
469+ ag .value_and_grad (objective )(params0 )
470+ assert counters ["download" ] == 2
471+ assert len (cache ) == 2
385472
386473 _reset_counters (counters )
387- sim = make_sim (params0 )
388- web .run (sim )
389- assert counters ["download" ] == 0
390- assert len (cache ) == 1
474+ ag .value_and_grad (objective )(params0 )
475+ assert counters ["download" ] == 1 # download field data
476+ assert len (cache ) == 2
391477
392478
393479def _test_load_cache_hit (monkeypatch , tmp_path , basic_simulation , fake_data ):
@@ -499,7 +585,9 @@ def _test_env_var_overrides(monkeypatch, tmp_path):
499585 manager ._reload ()
500586
501587
502- def test_cache_sequential (monkeypatch , tmp_path , tmp_path_factory , basic_simulation , fake_data ):
588+ def test_cache_sequential (
589+ monkeypatch , tmp_path , tmp_path_factory , basic_simulation , fake_data , request
590+ ):
503591 """Run all critical cache tests in sequence to ensure stability."""
504592 monkeypatch .setattr (config .local_cache , "enabled" , True )
505593
@@ -514,7 +602,7 @@ def test_cache_sequential(monkeypatch, tmp_path, tmp_path_factory, basic_simulat
514602 _test_cache_eviction_by_size (monkeypatch , tmp_path_factory , basic_simulation )
515603 _test_run_cache_hit_async (monkeypatch , basic_simulation , tmp_path )
516604 _test_job_run_cache (monkeypatch , basic_simulation , tmp_path )
517- _test_autograd_cache (monkeypatch )
605+ _test_autograd_cache (monkeypatch , request )
518606 _test_configure_cache_roundtrip (monkeypatch , tmp_path )
519607 _test_mode_solver_caching (monkeypatch , tmp_path )
520608 _test_verbosity (monkeypatch , basic_simulation )
0 commit comments