diff --git a/docs/examples/wavefront.ipynb b/docs/examples/wavefront.ipynb index 6d6c076..d66b894 100644 --- a/docs/examples/wavefront.ipynb +++ b/docs/examples/wavefront.ipynb @@ -273,7 +273,7 @@ } ], "source": [ - "propagated_w = W.propagate_z(3)\n", + "propagated_w = W.propagate(\"z\", 3)\n", "fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 4))\n", "\n", "rspace_abs_orig = np.abs(W.field_rspace[int(nt / 2), :, :]) ** 2\n", @@ -288,14 +288,6 @@ "ax2.set_title(\"Propagated to 3m\");" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "45df1eb1-7c24-4f93-b4b4-0d1d1fcc4d65", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": 12, @@ -330,7 +322,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/vy/s8_hc3m10fddm6n_43cf_m8r0000gn/T/ipykernel_6966/3266762373.py:5: OptimizeWarning: Covariance of the parameters could not be estimated\n", + "/var/folders/vy/s8_hc3m10fddm6n_43cf_m8r0000gn/T/ipykernel_9525/3266762373.py:5: OptimizeWarning: Covariance of the parameters could not be estimated\n", " popt_gaussian, pcov_gaussian = curve_fit(gaussian_func, xdata, ydata, p0=initial_guess)\n" ] }, @@ -385,12 +377,12 @@ "\n", "domain_x = X.rspace_domain[1]\n", "\n", - "X.propagate_z(0.0, inplace=True)\n", + "X.propagate(\"z\", 0.0, inplace=True)\n", "\n", "for zi in range(0, mz * zgrid):\n", " if zi > 0:\n", " print(\"Propagating to: \", zi * dz)\n", - " X.propagate_z(dz, inplace=True)\n", + " X.propagate(\"z\", dz, inplace=True)\n", " wf = np.abs(X.field_rspace[int(nt / 2), :, :]) ** 2\n", " \n", " popt_gaussian, ydata_fit, FWHM, roots = gaussian_fit(\n", diff --git a/pmd_beamphysics/wavefront.py b/pmd_beamphysics/wavefront.py index f9f4033..29b3c0c 100644 --- a/pmd_beamphysics/wavefront.py +++ b/pmd_beamphysics/wavefront.py @@ -5,7 +5,7 @@ import copy import dataclasses import logging -from typing import Any, List, Optional, Sequence, Tuple +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import scipy.constants @@ -425,7 +425,7 @@ def domains_kxky( ) -def drift_kernel( +def drift_kernel_z( domains_kxky: List[np.ndarray], z: float, wavelength: float, @@ -435,19 +435,19 @@ def drift_kernel( return np.exp(-1j * z * np.pi * wavelength * (kx**2 + ky**2)) -def drift_propagator( +def drift_propagator_z( field_kspace: np.ndarray, domains_kxky: List[np.ndarray], z: float, wavelength: float, ): """Fresnel propagator in paraxial approximation to distance z [m].""" - return field_kspace * drift_kernel( + return field_kspace * drift_kernel_z( domains_kxky=domains_kxky, z=z, wavelength=wavelength ) -def thin_lens_kernel( +def thin_lens_kernel_xy( wavelength: float, ranges: Ranges, grid: Sequence[int], @@ -638,6 +638,93 @@ def __init__( self._ranges = tuple(ranges) self._pad = WavefrontPadding.from_array(field_rspace, pad=pad, fix=True) + def __copy__(self) -> Wavefront: + res = Wavefront.__new__(Wavefront) + res._phasors = self._phasors + res._field_rspace_shape = self._field_rspace_shape + res._field_rspace = self._field_rspace + res._field_kspace = self._field_kspace + res._wavelength = self._wavelength + res._ranges = self._ranges + res._pad = self._pad + return res + + def __deepcopy__(self, memo) -> Wavefront: + res = Wavefront.__new__(Wavefront) + res._phasors = self._phasors + res._field_rspace_shape = self._field_rspace_shape + res._field_rspace = ( + np.copy(self._field_rspace) if self._field_rspace is not None else None + ) + res._field_kspace = ( + np.copy(self._field_kspace) if self._field_kspace is not None else None + ) + res._wavelength = self._wavelength + res._ranges = self._ranges + res._pad = self._pad + return res + + def __eq__(self, other: Any) -> bool: + if type(self) is not type(other): + return False + return all( + ( + self._field_rspace_shape == other._field_rspace_shape, + np.all(self._field_rspace == other._field_rspace), + np.all(self._field_kspace == other._field_kspace), + self._wavelength == other._wavelength, + self._ranges == other._ranges, + self._pad == other._pad, + ) + ) + + @classmethod + def gaussian_pulse( + cls, + dims: Tuple[int, int, int], + wavelength: float, + nphotons: float, + zR: float, + sigma_t: float, + ranges: Ranges, + pad: Optional[Sequence[int]] = None, + dtype=np.complex64, + ): + """ + Generate a complex three-dimensional spatio-temporal Gaussian profile + in terms of the q parameter. + + Parameters + ---------- + wavelength : float + Wavelength (lambda0) [m]. + nphotons : float + Number of photons. + zR : float + Rayleigh range [m]. + sigma_t : float + Time RMS [s] + + Returns + ------- + Wavefront + """ + pulse = create_gaussian_pulse_3d_with_q( + wavelength=wavelength, + nphotons=nphotons, + zR=zR, + sigma_t=sigma_t, + ranges=ranges, + grid=dims, + dtype=dtype, + ) + return cls( + field_rspace=pulse, + wavelength=wavelength, + ranges=ranges, + pad=pad, + ) + @property def rspace_domain(self): """ @@ -739,41 +826,6 @@ def _ifft(self): workers=workers, )[self._pad.ifft_slices] - def propagate_z(self, z_prop: float, *, inplace: bool = False) -> Wavefront: - """ - Propagate this Wavefront in-place along Z in meters. - - Parameters - ---------- - z_prop : float - Distance in meters. - inplace : bool, default=False - Perform the operation in-place on this wavefront object. - - Returns - ------- - Wavefront - This object if `inplace=True` or a new copy if `inplace=False`. - """ - if not inplace: - wavefront = copy.copy(self) - return wavefront.propagate_z(z_prop, inplace=True) - - z_prop = float(z_prop) - self._field_kspace = drift_propagator( - field_kspace=self.field_kspace, - domains_kxky=domains_kxky( - grids=self._pad.grid, - pads=self._pad.pad, - deltas=self._rspace_deltas, - ), - wavelength=self._wavelength, - z=z_prop, - ) - # Invalidate the real space data - self._field_rspace = None - return self - @property def wavelength(self) -> float: """Wavelength of the wavefront [m].""" @@ -788,10 +840,10 @@ def pad(self): def ranges(self): return self._ranges - def focusing_element( + def focus( self, - f_lens_x: float, - f_lens_y: float, + plane: Union[str, Tuple[int, int]], + focus: Tuple[float, float], *, inplace: bool = False, ) -> Wavefront: @@ -800,10 +852,10 @@ def focusing_element( Parameters ---------- - f_lens_x : float - Focal length of the lens in x [m]. - f_lens_y : float - Focal length of the lens in y [m]. + plane : str or (int, int) + Plane identifier (e.g., "xy") or dimension indices (e.g., ``(1, 2)``) + focus : (float, float) + Focal length of the lens in each dimension [m]. inplace : bool, default=False Perform the operation in-place on this wavefront object. @@ -812,104 +864,66 @@ def focusing_element( Wavefront This object if `inplace=True` or a new copy if `inplace=False`. """ + if plane not in ("xy", (1, 2)): + raise NotImplementedError(f"Unsupported plane: {plane}") + if not inplace: wavefront = copy.copy(self) - return wavefront.focusing_element(f_lens_x, f_lens_y, inplace=True) + return wavefront.focus(plane, focus, inplace=True) - self._field_rspace = self.field_rspace * thin_lens_kernel( + self._field_rspace = self.field_rspace * thin_lens_kernel_xy( wavelength=self.wavelength, ranges=self._ranges, grid=self._pad.grid, - f_lens_x=f_lens_x, - f_lens_y=f_lens_y, + f_lens_x=focus[0], + f_lens_y=focus[1], ) # Invalidate the spectral data self._field_kspace = None return self - def __deepcopy__(self, memo) -> Wavefront: - res = Wavefront.__new__(Wavefront) - res._phasors = self._phasors - res._field_rspace_shape = self._field_rspace_shape - res._field_rspace = ( - np.copy(self._field_rspace) if self._field_rspace is not None else None - ) - res._field_kspace = ( - np.copy(self._field_kspace) if self._field_kspace is not None else None - ) - res._wavelength = self._wavelength - res._ranges = self._ranges - res._pad = self._pad - return res - - def __copy__(self) -> Wavefront: - res = Wavefront.__new__(Wavefront) - res._phasors = self._phasors - res._field_rspace_shape = self._field_rspace_shape - res._field_rspace = self._field_rspace - res._field_kspace = self._field_kspace - res._wavelength = self._wavelength - res._ranges = self._ranges - res._pad = self._pad - return res - - def __eq__(self, other: Any) -> bool: - if type(self) is not type(other): - return False - return all( - ( - self._field_rspace_shape == other._field_rspace_shape, - np.all(self._field_rspace == other._field_rspace), - np.all(self._field_kspace == other._field_kspace), - self._wavelength == other._wavelength, - self._ranges == other._ranges, - self._pad == other._pad, - ) - ) - - @classmethod - def gaussian_pulse( - cls, - dims: Tuple[int, int, int], - wavelength: float, - nphotons: float, - zR: float, - sigma_t: float, - ranges: Ranges, - pad: Optional[Sequence[int]] = None, - dtype=np.complex64, - ): + def propagate( + self, + direction: Union[str, int], + distance: float, + *, + inplace: bool = False, + ) -> Wavefront: """ - Generate a complex three-dimensional spatio-temporal Gaussian profile - in terms of the q parameter. + Propagate this Wavefront in-place along Z in meters. Parameters ---------- - wavelength : float - Wavelength (lambda0) [m]. - nphotons : float - Number of photons. - zR : float - Rayleigh range [m]. - sigma_t : float - Time RMS [s] + direction : str or (int, int) + Propagation direction dimension name (e.g., "z") or dimension index (e.g., `2`) + z_prop : float + Distance in meters. + inplace : bool, default=False + Perform the operation in-place on this wavefront object. Returns ------- Wavefront + This object if `inplace=True` or a new copy if `inplace=False`. """ - pulse = create_gaussian_pulse_3d_with_q( - wavelength=wavelength, - nphotons=nphotons, - zR=zR, - sigma_t=sigma_t, - ranges=ranges, - grid=dims, - dtype=dtype, - ) - return cls( - field_rspace=pulse, - wavelength=wavelength, - ranges=ranges, - pad=pad, + + if direction not in {"z", 2}: + raise NotImplementedError(f"Unsupported propagation direction: {direction}") + + if not inplace: + wavefront = copy.copy(self) + return wavefront.propagate(direction, distance, inplace=True) + + self._field_kspace = drift_propagator_z( + field_kspace=self.field_kspace, + domains_kxky=domains_kxky( + grids=self._pad.grid, + pads=self._pad.pad, + deltas=self._rspace_deltas, + ), + wavelength=self._wavelength, + z=float(distance), ) + # Invalidate the real space data + self._field_rspace = None + return self diff --git a/tests/test_wavefront.py b/tests/test_wavefront.py index d813793..49b6ce6 100644 --- a/tests/test_wavefront.py +++ b/tests/test_wavefront.py @@ -32,22 +32,22 @@ def wavefront() -> Wavefront: def test_smoke_propagate_z_in_place(wavefront: Wavefront) -> None: # Implicitly calculates the FFT: - wavefront.propagate_z(0.0, inplace=True) + wavefront.propagate(direction="z", distance=0.0, inplace=True) # Use the property to calculate the inverse fft: wavefront.field_rspace def test_smoke_propagate_z(wavefront: Wavefront) -> None: - new = wavefront.propagate_z(0.0, inplace=False) + new = wavefront.propagate(direction="z", distance=0.0, inplace=False) assert new is not wavefront def test_smoke_focusing_element_in_place(wavefront: Wavefront) -> None: - wavefront.focusing_element(1.0, 1.0, inplace=True) + wavefront.focus(plane="xy", focus=(1.0, 1.0), inplace=True) def test_smoke_focusing_element(wavefront: Wavefront) -> None: - new = wavefront.focusing_element(1.0, 1.0, inplace=False) + new = wavefront.focus(plane="xy", focus=(1.0, 1.0), inplace=False) assert new is not wavefront