Skip to content

Commit

Permalink
WIP: moving from time[fs]->z[m] and fixing along the way
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-lauer committed Sep 12, 2024
1 parent bf40ceb commit 6102467
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 52 deletions.
42 changes: 11 additions & 31 deletions docs/examples/wavefront.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@
"W = Wavefront.gaussian_pulse(\n",
" dims=(101, 101, 801),\n",
" wavelength=1.35e-8,\n",
" grid_spacing=(6e-6, 6e-6, 0.0625),\n",
" grid_spacing=(6e-6, 6e-6, 1e-7),\n",
" pad=(100, 100, 40),\n",
" nphotons=1e12,\n",
" zR=2.0,\n",
" sigma_t=5,\n",
" sigma_z=3e-6,\n",
")"
]
},
Expand All @@ -71,26 +71,6 @@
"W.plot(\"xy\", rspace=False);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09b8141f-35bc-4a87-ba5e-da688f0ff763",
"metadata": {},
"outputs": [],
"source": [
"nt, nx, ny = np.shape(W.rmesh)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "237796aa-4d23-462f-88ce-299cea8e2eed",
"metadata": {},
"outputs": [],
"source": [
"print(W.rmesh.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -117,7 +97,7 @@
"source": [
"plt.figure(figsize=(4, 2))\n",
"plt.plot(W.rspace_domain[2], np.abs(get_longitudinal_slice(W.rmesh)))\n",
"plt.xlabel(\"Time [fs]\");"
"plt.xlabel(\"Position z [m]\");"
]
},
{
Expand Down Expand Up @@ -240,16 +220,16 @@
"X = Wavefront.gaussian_pulse(\n",
" dims=(101, 101, 801),\n",
" wavelength=1.35e-8,\n",
" grid_spacing=(6e-6, 6e-6, 0.0625),\n",
" grid_spacing=(6e-6, 6e-6, 1e-7),\n",
" pad=(100, 100, 40),\n",
" nphotons=1e12,\n",
" zR=zR,\n",
" sigma_t=5,\n",
" zR=2.0,\n",
" sigma_z=3e-6,\n",
")\n",
"\n",
"mz = 3\n",
"zgrid = 5\n",
"dz = 0.25\n",
"mz = 3 # mz * zgrid = number of steps\n",
"zgrid = 5 #\n",
"dz = 0.25 # drift distance [m]\n",
"\n",
"nx, ny = get_transverse_slice(X.rmesh).shape\n",
"wfz = np.zeros((mz * zgrid, nx, ny))\n",
Expand Down Expand Up @@ -277,7 +257,7 @@
" popt_gaussian, ydata_fit, FWHM, roots = gaussian_fit(\n",
" domain_x,\n",
" np.sum(wf, axis=1),\n",
" [8e19, 0.0, 0.0002, 0.0],\n",
" [8e24, 0.0, 0.0002, 0.0],\n",
" )\n",
" wfz[zi, :, :] = wf\n",
" fwhmz_fit[zi] = FWHM\n",
Expand All @@ -302,7 +282,7 @@
" extent=[-mz * zgrid * dz, mz * zgrid * dz, ymin, ymax],\n",
" aspect=\"auto\",\n",
" vmin=0.0,\n",
" vmax=1.6e20,\n",
" vmax=1.6e26,\n",
")\n",
"\n",
"plt.plot(z, wz, \"-\", linewidth=2, color=\"white\")\n",
Expand Down
77 changes: 56 additions & 21 deletions pmd_beamphysics/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def fix_padding(


def get_shifts(
dims: Sequence[int],
ranges: Ranges,
pads: Sequence[int],
deltas: Sequence[float],
Expand All @@ -336,6 +337,8 @@ def get_shifts(
Parameters
----------
dims :
Grid dimensions
ranges : tuple of (float, float) pairs
Low and high domain range for each dimension of the wavefront.
pads : tuple of ints
Expand All @@ -351,8 +354,17 @@ def get_shifts(

assert len(pads) == len(deltas) > 1
spans = tuple(domain[1] - domain[0] for domain in ranges)

def fix_even(dim: int, delta: float) -> float:
# For odd number of grid points, no fix is required
if dim % 2 == 1:
return 0.0
# For an even number of grid points, we add half a grid step
return delta / 2.0

return tuple(
span / 2.0 + pad * delta for span, pad, delta in zip(spans, pads, deltas)
span / 2.0 + pad * delta + fix_even(dim, delta)
for dim, span, pad, delta in zip(dims, spans, pads, deltas)
)


Expand All @@ -365,10 +377,10 @@ def conversion_coeffs(wavelength: float, dim: int) -> Tuple[float, ...]:
"""
Conversion coefficients to (eV, radians, radians, ...).
Omega, theta-x, theta-y.
Theta-x, theta-y, omega.
"""
k0 = calculate_k0(wavelength)
hbar = scipy.constants.hbar / scipy.constants.e * 1.0e15 # fs-eV
hbar = scipy.constants.hbar / scipy.constants.e * scipy.constants.c # eV-m
return tuple([2.0 * np.pi / k0] * (dim - 1) + [2.0 * np.pi * hbar])


Expand Down Expand Up @@ -463,7 +475,7 @@ def create_gaussian_pulse_3d_with_q(
wavelength: float,
nphotons: float,
zR: float,
sigma_t: float,
sigma_z: float,
grid_spacing: Sequence[float],
grid: Sequence[int],
dtype=np.complex64,
Expand All @@ -474,13 +486,13 @@ def create_gaussian_pulse_3d_with_q(
Parameters
----------
wavelength : float
Wavelength (lambda0) [m].
Wavelength (lambda0). [m]
nphotons : float
Number of photons.
zR : float
Rayleigh range [m].
sigma_t : float
Time RMS [s]
Rayleigh range. [m]
sigma_z : float
Pulse length RMS. [m]
grid_spacing : tuple of floats
Per-axis grid spacing.
grid : tuple of ints
Expand All @@ -497,23 +509,23 @@ def create_gaussian_pulse_3d_with_q(
raise ValueError("`grid` must be of length 3 for a 3D gaussian")

ranges = get_ranges_for_grid_spacing(grid_spacing=grid_spacing, dims=grid)
min_t, max_t = ranges[-1]
min_z, max_z = ranges[-1]

k0 = calculate_k0(wavelength)
t_mid = (max_t + min_t) / 2.0
x_mesh, y_mesh, t_mesh = nd_space_mesh(ranges=ranges, sizes=grid)
z_mid = (max_z + min_z) / 2.0
x_mesh, y_mesh, z_mesh = nd_space_mesh(ranges=ranges, sizes=grid)
qx = 1j * zR
qy = 1j * zR

ux = 1.0 / np.sqrt(qx) * np.exp(-1j * k0 * x_mesh**2 / 2.0 / qx)
uy = 1.0 / np.sqrt(qy) * np.exp(-1j * k0 * y_mesh**2 / 2.0 / qy)
ut = (1.0 / (np.sqrt(2.0 * np.pi) * sigma_t)) * np.exp(
-((t_mesh - t_mid) ** 2) / 2.0 / sigma_t**2
uz = (1.0 / (np.sqrt(2.0 * np.pi) * sigma_z)) * np.exp(
-((z_mesh - z_mid) ** 2) / 2.0 / sigma_z**2
)

eta = 2.0 * k0 * zR * sigma_t / np.sqrt(np.pi)
eta = 2.0 * k0 * zR * sigma_z / np.sqrt(np.pi)

pulse = np.sqrt(eta) * np.sqrt(nphotons) * ux * uy * ut
pulse = np.sqrt(eta) * np.sqrt(nphotons) * ux * uy * uz
return pulse.astype(dtype)


Expand Down Expand Up @@ -835,7 +847,7 @@ def gaussian_pulse(
wavelength: float,
nphotons: float,
zR: float,
sigma_t: float,
sigma_z: float,
grid_spacing: Sequence[float],
pad: Optional[Sequence[int]] = None,
dtype=np.complex64,
Expand All @@ -854,8 +866,8 @@ def gaussian_pulse(
Number of photons.
zR : float
Rayleigh range [m].
sigma_t : float
Time RMS [s]
sigma_z : float
Pulse length RMS. [m]
Returns
-------
Expand All @@ -865,7 +877,7 @@ def gaussian_pulse(
wavelength=wavelength,
nphotons=nphotons,
zR=zR,
sigma_t=sigma_t,
sigma_z=sigma_z,
grid_spacing=grid_spacing,
grid=dims,
dtype=dtype,
Expand All @@ -889,13 +901,35 @@ def rspace_domain(self):
"""
return cartesian_domain(ranges=self.ranges, grids=self._pad.grid)

@property
def kspace_domain(self):
"""
Reciprocal space domain values in all dimensions.
For each dimension of the wavefront, this is the evenly-spaced set of values over
its specified range.
"""

coeffs = conversion_coeffs(
wavelength=self.wavelength,
dim=len(self._rmesh_shape),
)
return nd_kspace_domains(
coeffs=coeffs,
sizes=self._rmesh_shape,
pads=self.pad.pad,
steps=self.grid_spacing,
shifted=True,
)

def _calc_phasors(self) -> Tuple[np.ndarray, ...]:
"""Calculate phasors for each dimension of the cartesian domain."""
coeffs = conversion_coeffs(
wavelength=self._wavelength,
dim=len(self._rmesh_shape),
)
shifts = get_shifts(
dims=self._pad.grid,
ranges=self.ranges,
pads=self._pad.pad,
deltas=self.grid_spacing,
Expand Down Expand Up @@ -1185,6 +1219,7 @@ def plot(
data = self.rmesh
else:
data = self.kmesh
# TODO change labels with prefix of 'theta'

if transpose:
data = data.T
Expand All @@ -1211,7 +1246,7 @@ def plot(

def plot(dat, title: str):
ax = remaining_axes.pop(0)
img = ax.imshow(np.sum(dat, axis=sum_axis), cmap=cmap)
img = ax.imshow(np.mean(dat, axis=sum_axis), cmap=cmap)
if xlim is not None:
ax.set_xlim(xlim)
if ylim is not None:
Expand All @@ -1229,7 +1264,7 @@ def plot(dat, title: str):
plot(np.imag(data), title="Imaginary")

if show_abs:
plot(np.abs(data), f"|{plane}|")
plot(np.abs(data) ** 2, f"|{plane}|**2")

if show_phase:
plot(np.angle(data), title="Phase")
Expand Down

0 comments on commit 6102467

Please sign in to comment.