Skip to content

Commit

Permalink
refc[adjoint]: filter zero VJPs in setup_adj
Browse files Browse the repository at this point in the history
  • Loading branch information
yaugenst-flex committed Feb 28, 2025
1 parent e8af623 commit 3708817
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def objective(*args):
sim_data = run(sim, task_name="adjoint_test", verbose=False)
return 0 * postprocess(sim_data)

with AssertLogLevel("WARNING", contains_str="fields are zero"):
with AssertLogLevel("WARNING", contains_str="no sources"):
grad = ag.grad(objective)(params0)


Expand Down
16 changes: 8 additions & 8 deletions tidy3d/web/api/autograd/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,12 +626,6 @@ def _run_bwd(
def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
"""dJ/d{sim.traced_fields()} as a function of Function of dJ/d{data.traced_fields()}"""

# filter out any data_fields_vjp with all 0's and skip the corresponding adjoint simulation
data_fields_vjp = {k: v for k, v in data_fields_vjp.items() if not np.allclose(v, 0)}
if not data_fields_vjp:
td.log.warning("All VJP fields are zero, skipping adjoint simulation")
return {k: 0 * v for k, v in sim_fields_original.items()}

# build the (possibly multiple) adjoint simulations
sims_adj = setup_adj(
data_fields_vjp=data_fields_vjp,
Expand Down Expand Up @@ -858,8 +852,14 @@ def setup_adj(

td.log.info("Running custom vjp (adjoint) pipeline.")

# immediately filter out any data_vjps with all 0's in the data
data_fields_vjp = {key: get_static(value) for key, value in data_fields_vjp.items()}
# filter out any data_fields_vjp with all 0's
data_fields_vjp = {
k: get_static(v) for k, v in data_fields_vjp.items() if not np.allclose(v, 0)
}

# if all entries are zero, there is no adjoint sim to run
if not data_fields_vjp:
return []

# start with the full simulation data structure and either zero out the fields
# that have no tracer data for them or insert the tracer data
Expand Down

0 comments on commit 3708817

Please sign in to comment.