From 214b3c1ecf441c07351544086085db15b17bcf89 Mon Sep 17 00:00:00 2001 From: Ken Lauer <152229072+ken-lauer@users.noreply.github.com> Date: Wed, 4 Sep 2024 13:58:22 -0700 Subject: [PATCH] WIP: max divergence padding factor --- docs/examples/wavefront.ipynb | 2 +- pmd_beamphysics/wavefront.py | 38 +++++++++++++++++++++++++++++++++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/docs/examples/wavefront.ipynb b/docs/examples/wavefront.ipynb index 4d88965..64e1e5e 100644 --- a/docs/examples/wavefront.ipynb +++ b/docs/examples/wavefront.ipynb @@ -264,7 +264,7 @@ " np.sum(wf, axis=1),\n", " [8e19, 0.0, 0.0002, 0.0],\n", " )\n", - " wfz[zi, :, :] = wf.copy()\n", + " wfz[zi, :, :] = wf\n", " fwhmz_fit[zi] = FWHM\n", "\n", "wfz_wx = np.sum(wfz, axis=1)\n", diff --git a/pmd_beamphysics/wavefront.py b/pmd_beamphysics/wavefront.py index 78660b5..d22040c 100644 --- a/pmd_beamphysics/wavefront.py +++ b/pmd_beamphysics/wavefront.py @@ -542,6 +542,32 @@ def create_gaussian_pulse_3d_with_q( return pulse.astype(dtype) +def max_divergence_padding_factor( + theta_max: float, + drift_distance: float, + beam_size: float, +) -> float: + """ + Calculate the padding factor for the maximum divergence scenario. + + Parameters + ---------- + theta_max : float + Maximum divergence [rad] + drift_distance : float + Drift propagation distance [m] + beam_size : float + Size of the beam at z=0 [m] + + Returns + ------- + float + Factor to increase the initial number of grid points, per dimension. + """ + # TODO: balticfish + return (theta_max * drift_distance) / beam_size + + @dataclasses.dataclass(frozen=True) class WavefrontPadding: """ @@ -775,6 +801,11 @@ def _check_metadata(self) -> None: "each should describe the cartesian range of the corresponding axis." ) + if len(self.metadata.mesh.axis_labels) != len(self._field_rspace_shape): + raise ValueError( + "'axis_labels' must have the same number of dimensions as `field_rspace`" + ) + def __copy__(self) -> Wavefront: res = Wavefront.__new__(Wavefront) res._phasors = self._phasors @@ -1184,14 +1215,17 @@ def plot( def plot(dat, title: str): ax = remaining_axes.pop(0) - ax.imshow(np.sum(dat, axis=sum_axis), cmap=cmap) + img = ax.imshow(np.sum(dat, axis=sum_axis), cmap=cmap) if xlim is not None: ax.set_xlim(xlim) if ylim is not None: ax.set_ylim(ylim) if not ax.get_title(): ax.set_title(title) + images.append(img) + return img + images = [] if show_real: plot(np.real(data), title="Real") @@ -1212,7 +1246,7 @@ def plot(dat, title: str): logger.info(f"Saving plot to {save!r}") fig.savefig(save) - return fig, axs + return fig, axs, images @classmethod def from_genesis4(