Skip to content

Commit

Permalink
REF: inplace kwarg instead
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-lauer committed Aug 14, 2024
1 parent 89c4268 commit db98ddb
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 97 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ jobs:
build:
runs-on: ${{ matrix.os }}
defaults:
run:
shell: bash
run:
shell: bash
strategy:
matrix:
os: [ubuntu-latest]
python-version: [3.9]
python-version: [3.10]
steps:
- uses: actions/checkout@v3
- uses: conda-incubator/setup-miniconda@v2
Expand All @@ -21,11 +21,11 @@ jobs:
channels: conda-forge
activate-environment: pmd_beamphysics-dev
environment-file: environment.yml
# - name: flake8
# shell: bash -l {0}
# run: |
# flake8 .

# - name: flake8
# shell: bash -l {0}
# run: |
# flake8 .

- name: Install openPMD-beamphysics
shell: bash -l {0}
Expand Down
10 changes: 5 additions & 5 deletions docs/examples/wavefront.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"from scipy.optimize import curve_fit\n",
"from scipy.interpolate import UnivariateSpline\n",
"\n",
"from pmd_beamphysics.wavefront import Wavefront, propagate_z"
"from pmd_beamphysics.wavefront import Wavefront"
]
},
{
Expand Down Expand Up @@ -251,7 +251,7 @@
}
],
"source": [
"propagated_x = propagate_z(X, 3)\n",
"propagated_x = X.propagate_z(3)\n",
"_, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 4))\n",
"\n",
"ax1.imshow(np.abs(X.field_rspace[int(nt / 2), :, :]) ** 2)\n",
Expand Down Expand Up @@ -292,7 +292,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/vy/s8_hc3m10fddm6n_43cf_m8r0000gn/T/ipykernel_52096/3266762373.py:5: OptimizeWarning: Covariance of the parameters could not be estimated\n",
"/var/folders/vy/s8_hc3m10fddm6n_43cf_m8r0000gn/T/ipykernel_65481/3266762373.py:5: OptimizeWarning: Covariance of the parameters could not be estimated\n",
" popt_gaussian, pcov_gaussian = curve_fit(gaussian_func, xdata, ydata, p0=initial_guess)\n"
]
},
Expand Down Expand Up @@ -347,12 +347,12 @@
"\n",
"domain_x = X.rspace_domain[1]\n",
"\n",
"X.propagate_z(0.0)\n",
"X.propagate_z(0.0, inplace=True)\n",
"\n",
"for zi in range(0, mz * zgrid):\n",
" if zi > 0:\n",
" print(\"Propagating to: \", zi * dz)\n",
" X.propagate_z(dz)\n",
" X.propagate_z(dz, inplace=True)\n",
" wf = np.abs(X.field_rspace[int(nt / 2), :, :]) ** 2\n",
" \n",
" popt_gaussian, ydata_fit, FWHM, roots = gaussian_fit(\n",
Expand Down
4 changes: 1 addition & 3 deletions pmd_beamphysics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .status import ParticleStatus
from .fields.fieldmesh import FieldMesh
from .readers import particle_paths
from .wavefront import Wavefront, propagate_z, focusing_element
from .wavefront import Wavefront
from .writers import pmd_init

from . import _version
Expand All @@ -17,6 +17,4 @@
"pmd_init",
"single_particle",
"Wavefront",
"propagate_z",
"focusing_element",
]
102 changes: 26 additions & 76 deletions pmd_beamphysics/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,26 +721,26 @@ def _ifft(self):
workers=workers,
)[*self._pad.ifft_slices]

def propagate_z(self, z_prop: float):
def propagate_z(self, z_prop: float, *, inplace: bool = False) -> Wavefront:
"""
Propagate this Wavefront in-place along Z in meters.
Parameters
----------
z_prop : float
Distance in meters.
inplace : bool, default=False
Perform the operation in-place on this wavefront object.
Returns
-------
np.ndarray
Propagated k-space data.
See Also
--------
`propagate_z`
For a version which returns a propagated copy of the wavefront,
instead of performing it in-place.
Wavefront
This object if `inplace=True` or a new copy if `inplace=False`.
"""
if not inplace:
wavefront = copy.copy(self)
return wavefront.propagate_z(z_prop, inplace=True)

z_prop = float(z_prop)
self._field_kspace = drift_propagator(
field_kspace=self.field_kspace,
Expand All @@ -754,7 +754,7 @@ def propagate_z(self, z_prop: float):
)
# Invalidate the real space data
self._field_rspace = None
return self._field_kspace
return self

@property
def wavelength(self) -> float:
Expand All @@ -770,7 +770,13 @@ def pad(self):
def ranges(self):
return self._ranges

def focusing_element(self, f_lens_x: float, f_lens_y: float):
def focusing_element(
self,
f_lens_x: float,
f_lens_y: float,
*,
inplace: bool = False,
) -> Wavefront:
"""
Apply thin lens focusing.
Expand All @@ -780,18 +786,18 @@ def focusing_element(self, f_lens_x: float, f_lens_y: float):
Focal length of the lens in x [m].
f_lens_y : float
Focal length of the lens in y [m].
inplace : bool, default=False
Perform the operation in-place on this wavefront object.
Returns
-------
np.ndarray
Focused r-space data.
See Also
--------
`focusing_element`
For a version which returns a focused copy of the wavefront,
instead of performing it in-place.
Wavefront
This object if `inplace=True` or a new copy if `inplace=False`.
"""
if not inplace:
wavefront = copy.copy(self)
return wavefront.focusing_element(f_lens_x, f_lens_y, inplace=True)

self._field_rspace = self.field_rspace * thin_lens_kernel(
wavelength=self.wavelength,
ranges=self._ranges,
Expand All @@ -801,7 +807,7 @@ def focusing_element(self, f_lens_x: float, f_lens_y: float):
)
# Invalidate the spectral data
self._field_kspace = None
return self._field_rspace
return self

def __deepcopy__(self, memo) -> Wavefront:
res = Wavefront.__new__(Wavefront)
Expand Down Expand Up @@ -891,59 +897,3 @@ def gaussian_pulse(
ranges=ranges,
pad=pad,
)


def propagate_z(wavefront: Wavefront, z_prop: float) -> Wavefront:
"""
Propagate a Wavefront along Z in meters and get a new `Wavefront` object.
Parameters
----------
wavefront : Wavefront
The Wavefront object to propagate.
z_prop : float
Distance in meters.
Returns
-------
Wavefront
Propagated Wavefront object.
See Also
--------
`Wavefront.propagate_z`
For an in-place version.
"""
wavefront = copy.copy(wavefront)
wavefront.propagate_z(z_prop)
return wavefront


def focusing_element(
wavefront: Wavefront, f_lens_x: float, f_lens_y: float
) -> Wavefront:
"""
Apply thin lens focusing to `wavefront` and get a new `Wavefront` object.
Parameters
----------
wavefront : Wavefront
The Wavefront object to focus.
f_lens_x : float
Focal length of the lens in x [m].
f_lens_y : float
Focal length of the lens in y [m].
Returns
-------
Wavefront
Focused Wavefront.
See Also
--------
`Wavefront.focusing_element`
For an in-place version.
"""
wavefront = copy.copy(wavefront)
wavefront.focusing_element(f_lens_x, f_lens_y)
return wavefront
10 changes: 5 additions & 5 deletions tests/test_wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from pmd_beamphysics import Wavefront, focusing_element, propagate_z
from pmd_beamphysics import Wavefront
from pmd_beamphysics.wavefront import (
WavefrontPadding,
get_num_fft_workers,
Expand Down Expand Up @@ -32,22 +32,22 @@ def wavefront() -> Wavefront:

def test_smoke_propagate_z_in_place(wavefront: Wavefront) -> None:
# Implicitly calculates the FFT:
wavefront.propagate_z(0.0)
wavefront.propagate_z(0.0, inplace=True)
# Use the property to calculate the inverse fft:
wavefront.field_rspace


def test_smoke_propagate_z(wavefront: Wavefront) -> None:
new = propagate_z(wavefront, 0.0)
new = wavefront.propagate_z(0.0, inplace=False)
assert new is not wavefront


def test_smoke_focusing_element_in_place(wavefront: Wavefront) -> None:
wavefront.focusing_element(1.0, 1.0)
wavefront.focusing_element(1.0, 1.0, inplace=True)


def test_smoke_focusing_element(wavefront: Wavefront) -> None:
new = focusing_element(wavefront, 1.0, 1.0)
new = wavefront.focusing_element(1.0, 1.0, inplace=False)
assert new is not wavefront


Expand Down

0 comments on commit db98ddb

Please sign in to comment.