Skip to content

Commit

Permalink
WIP: max divergence padding factor
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-lauer committed Sep 4, 2024
1 parent 4b5da0a commit 214b3c1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/examples/wavefront.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@
" np.sum(wf, axis=1),\n",
" [8e19, 0.0, 0.0002, 0.0],\n",
" )\n",
" wfz[zi, :, :] = wf.copy()\n",
" wfz[zi, :, :] = wf\n",
" fwhmz_fit[zi] = FWHM\n",
"\n",
"wfz_wx = np.sum(wfz, axis=1)\n",
Expand Down
38 changes: 36 additions & 2 deletions pmd_beamphysics/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,32 @@ def create_gaussian_pulse_3d_with_q(
return pulse.astype(dtype)


def max_divergence_padding_factor(
theta_max: float,
drift_distance: float,
beam_size: float,
) -> float:
"""
Calculate the padding factor for the maximum divergence scenario.
Parameters
----------
theta_max : float
Maximum divergence [rad]
drift_distance : float
Drift propagation distance [m]
beam_size : float
Size of the beam at z=0 [m]
Returns
-------
float
Factor to increase the initial number of grid points, per dimension.
"""
# TODO: balticfish
return (theta_max * drift_distance) / beam_size


@dataclasses.dataclass(frozen=True)
class WavefrontPadding:
"""
Expand Down Expand Up @@ -775,6 +801,11 @@ def _check_metadata(self) -> None:
"each should describe the cartesian range of the corresponding axis."
)

if len(self.metadata.mesh.axis_labels) != len(self._field_rspace_shape):
raise ValueError(
"'axis_labels' must have the same number of dimensions as `field_rspace`"
)

def __copy__(self) -> Wavefront:
res = Wavefront.__new__(Wavefront)
res._phasors = self._phasors
Expand Down Expand Up @@ -1184,14 +1215,17 @@ def plot(

def plot(dat, title: str):
ax = remaining_axes.pop(0)
ax.imshow(np.sum(dat, axis=sum_axis), cmap=cmap)
img = ax.imshow(np.sum(dat, axis=sum_axis), cmap=cmap)
if xlim is not None:
ax.set_xlim(xlim)
if ylim is not None:
ax.set_ylim(ylim)
if not ax.get_title():
ax.set_title(title)
images.append(img)
return img

images = []
if show_real:
plot(np.real(data), title="Real")

Expand All @@ -1212,7 +1246,7 @@ def plot(dat, title: str):
logger.info(f"Saving plot to {save!r}")
fig.savefig(save)

return fig, axs
return fig, axs, images

@classmethod
def from_genesis4(
Expand Down

0 comments on commit 214b3c1

Please sign in to comment.