Skip to content

Commit

Permalink
Merge pull request #46 from ChristopherMayes/slice_stats_dev
Browse files Browse the repository at this point in the history
Add methods for slice_statistics and plotting into ParticleGroup
  • Loading branch information
ChristopherMayes authored Aug 10, 2023
2 parents 5759911 + 59846b7 commit 3dab63e
Show file tree
Hide file tree
Showing 7 changed files with 385 additions and 149 deletions.
264 changes: 192 additions & 72 deletions docs/examples/particle_examples.ipynb

Large diffs are not rendered by default.

37 changes: 13 additions & 24 deletions docs/examples/plot_examples.ipynb

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions pmd_beamphysics/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@
'gamma': r'\gamma',
'theta': r'\theta',
'charge': 'Q',
'twiss_alpha_x' : 'Twiss\ '+r'\alpha_x',
'twiss_beta_x' : 'Twiss\ '+r'\beta_x',
'twiss_gamma_x' : 'Twiss\ '+r'\gamma_x',
'twiss_eta_x' : 'Twiss\ '+r'\eta_x',
'twiss_etap_x' : 'Twiss\ '+r"\eta'_x",
'twiss_emit_x' : 'Twiss\ '+r'\epsilon_{x}',
'twiss_norm_emit_x' : 'Twiss\ '+r'\epsilon_{n, x}',
'twiss_alpha_y' : 'Twiss\ '+r'\alpha_y',
'twiss_beta_y' : 'Twiss\ '+r'\beta_y',
'twiss_gamma_y' : 'Twiss\ '+r'\gamma_y',
'twiss_eta_y' : 'Twiss\ '+r'\eta_y',
'twiss_etap_y' : 'Twiss\ '+r"\eta'_y",
'twiss_emit_y' : 'Twiss\ '+r'\epsilon_{y}',
'twiss_norm_emit_y' : 'Twiss\ '+r'\epsilon_{n, y}',
# 'species_charge',
# 'weight',
'average_current': r'I_{av}',
Expand Down
74 changes: 48 additions & 26 deletions pmd_beamphysics/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pmd_beamphysics.readers import particle_array, particle_paths
from pmd_beamphysics.species import charge_of, mass_of

from pmd_beamphysics.statistics import norm_emit_calc, normalized_particle_coordinate, particle_amplitude, particle_twiss_dispersion, matched_particles, resample_particles
from pmd_beamphysics.statistics import norm_emit_calc, normalized_particle_coordinate, particle_amplitude, particle_twiss_dispersion, matched_particles, resample_particles, slice_statistics
from pmd_beamphysics.writers import write_pmd_bunch, pmd_init

from h5py import File
Expand Down Expand Up @@ -549,30 +549,30 @@ def Jy(self):

def delta(self, key):
"""Attribute (array) relative to its mean"""
return getattr(self, key) - self.avg(key)
return self[key] - self.avg(key)


# Statistical property functions

def min(self, key):
"""Minimum of any key"""
return np.min(getattr(self, key))
return np.min(self[key]) # was: getattr(self, key)
def max(self, key):
"""Maximum of any key"""
return np.max(getattr(self, key))
return np.max(self[key])
def ptp(self, key):
"""Peak-to-Peak = max - min of any key"""
return np.ptp(getattr(self, key))
return np.ptp(self[key])

def avg(self, key):
"""Statistical average"""
dat = getattr(self, key) # equivalent to self.key for accessing properties above
dat = self[key] # equivalent to self.key for accessing properties above
if np.isscalar(dat):
return dat
return np.average(dat, weights=self.weight)
def std(self, key):
"""Standard deviation (actually sample)"""
dat = getattr(self, key)
dat = self[key]
if np.isscalar(dat):
return 0
avg_dat = self.avg(key)
Expand All @@ -586,7 +586,7 @@ def cov(self, *keys):
P.cov('x', 'px', 'y', 'py')
"""
dats = np.array([ getattr(self, key) for key in keys ])
dats = np.array([ self[key] for key in keys ])
return np.cov(dats, aweights=self.weight)

def histogramdd(self, *keys, bins=10, range=None):
Expand All @@ -604,9 +604,7 @@ def histogramdd(self, *keys, bins=10, range=None):
H, edges = np.histogramdd(np.array([self[k] for k in list(keys)]).T, weights=self.weight, bins=bins, range=range)

return H, edges





# Beam statistics
@property
Expand Down Expand Up @@ -900,27 +898,51 @@ def plot(self, key1='x', key2=None,
if return_figure:
return fig

def slice_plot(self, key='sigma_x',



def slice_statistics(self, *keys,
n_slice=100,
slice_key=None):
"""
Slice statistics
"""

if slice_key is None:
if self.in_t_coordinates:
slice_key = 'z'

else:
slice_key = 't'

if slice_key in ('t', 'delta_t'):
density_name = 'current'
else:
density_name = 'density'

keys = set(keys)
keys.add('mean_'+slice_key)
keys.add('ptp_'+slice_key)
keys.add('charge')
slice_dat = slice_statistics(self, n_slice=n_slice, slice_key=slice_key,
keys=keys)

slice_dat[density_name] = slice_dat['charge']/ slice_dat['ptp_'+slice_key]

return slice_dat

def slice_plot(self, *keys,
n_slice=100,
slice_key=None,
tex=True,
nice=True,
return_figure=False,
xlim=None,
ylim=None,
**kwargs):
"""
Slice statistics plot.
**kwargs):

"""

if not slice_key:
if self.in_t_coordinates:
slice_key = 'z'
else:
slice_key = 't'

fig = slice_plot(self, stat_key=key,
fig = slice_plot(self, *keys,
n_slice=n_slice,
slice_key=slice_key,
tex=tex,
Expand All @@ -930,7 +952,7 @@ def slice_plot(self, key='sigma_x',
**kwargs)

if return_figure:
return fig
return fig


# New constructors
Expand Down Expand Up @@ -1129,7 +1151,7 @@ def split_particles(particle_group, n_chunks = 100, key='z'):
"""

# Sorting
zlist = getattr(particle_group, key)
zlist = particle_group[key]
iz = np.argsort(zlist)

# Split particles into chunks
Expand Down
93 changes: 69 additions & 24 deletions pmd_beamphysics/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def plt_histogram(a, weights=None, bins=40):


def slice_plot(particle_group,
stat_key='sigma_x',
*keys,
n_slice=40,
slice_key='z',
slice_key=None,
xlim=None,
ylim=None,
tex=True,
Expand All @@ -54,63 +54,108 @@ def slice_plot(particle_group,
particle_group: ParticleGroup
The object to plot
stat_key: str, default = 'sigma_x'
Key to calculate the statistics
keys: iterable of str
Keys to calculate the statistics, e.g. `sigma_x`.
n_slice: int, default = 40
Number of slices
slice_key: str, default = 'z'
Should be 'z' or 't'
slice_key: str, default = None
The dimension to slice in. This is typically `t` or `z`.
`delta_t`, etc. are also allowed.
If None, `t` or `z` will automatically be determined.
ylim: tuple, default = None
Manual setting of the y-axis limits.
tex: bool, defaul = True
tex: bool, default = True
Use TEX for labels
Returns
-------
fig: matplotlib.figure.Figure
"""

# Allow a single key
#if isinstance(keys, str):
#
# keys = (keys, )

if slice_key is None:
if particle_group.in_t_coordinates:
slice_key = 'z'
else:
slice_key = 't'

# Special case for delta_
if slice_key.startswith('delta_'):
slice_key = slice_key[6:]
has_delta_prefix = True
else:
has_delta_prefix = False

# Get all data
x_key = 'mean_'+slice_key
y_key = stat_key
slice_dat = slice_statistics(particle_group, n_slice=n_slice, slice_key=slice_key,
keys=[x_key, y_key, 'ptp_'+slice_key, 'charge'])


slice_dat = particle_group.slice_statistics(*keys, n_slice=n_slice, slice_key=slice_key)
slice_dat['density'] = slice_dat['charge']/ slice_dat['ptp_'+slice_key]
y2_key = 'density'

# X-axis
x = slice_dat['mean_'+slice_key]
if has_delta_prefix:
x -= particle_group['mean_'+slice_key]
slice_key = 'delta_'+slice_key # restore

x, f1, p1, xmin, xmax = plottable_array(x, nice=nice, lim=xlim)
ux = p1+str(particle_group.units(slice_key))

# Y-axis

# Units check
ulist = [particle_group.units(k).unitSymbol for k in keys]
uy = ulist[0]
if not all([u==uy for u in ulist] ):
raise ValueError(f'Incompatible units: {ulist}')

ymin = max([slice_dat[k].min() for k in keys])
ymax = max([slice_dat[k].max() for k in keys])

_, f2, p2, ymin, ymax = plottable_array(np.array([ymin, ymax]), nice=nice, lim=ylim)
uy = p2 + uy

# Form Figure
fig, ax = plt.subplots(**kwargs)

# Get nice arrays
x, f1, prex, xmin, xmax = plottable_array(slice_dat[x_key], nice=nice, lim=xlim)
y, f2, prey, ymin, ymax = plottable_array(slice_dat[y_key], nice=nice, lim=ylim)
# Main curves
if len(keys) == 1:
color = 'black'
else:
color = None

for k in keys:
label = mathlabel(k, units=uy, tex=tex)
ax.plot(x, slice_dat[k]/f2, label=label, color=color)
if len(keys) > 1:
ax.legend()

# Density on r.h.s
y2, _, prey2, _, _ = plottable_array(slice_dat[y2_key], nice=nice, lim=None)

x_units = f'{prex}{particle_group.units(x_key)}'
y_units = f'{prey}{particle_group.units(y_key)}'

# Convert to Amps if possible
y2_units = f'C/{particle_group.units(x_key)}'
if y2_units == 'C/s':
y2_units = 'A'
y2_units = prey2+y2_units

# Labels
labelx = mathlabel(slice_key, units=x_units, tex=tex)
labely = mathlabel(y_key, units=y_units, tex=tex)
labelx = mathlabel(slice_key, units=ux, tex=tex)
labely = mathlabel(*keys, units=uy, tex=tex)
labely2 = mathlabel(y2_key, units=y2_units, tex=tex)

ax.set_xlabel(labelx)
ax.set_ylabel(labely)

# Main plot
ax.plot(x, y, color = 'black')


# rhs plot
ax2 = ax.twinx()
Expand Down
40 changes: 38 additions & 2 deletions pmd_beamphysics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def normalized_particle_coordinate(particle_group, key, twiss=None, mass_normali
# ---------------
# Other utilities

def slice_statistics(particle_group, keys=['mean_z'], n_slice=40, slice_key='z'):
def slice_statistics(particle_group, keys=['mean_z'], n_slice=40, slice_key=None):
"""
Slices a particle group into n slices and returns statistics from each sliced defined in keys.
Expand All @@ -413,12 +413,48 @@ def slice_statistics(particle_group, keys=['mean_z'], n_slice=40, slice_key='z'
Any key can be used to slice on.
"""

if slice_key is None:
if particle_group.in_t_coordinates:
slice_key = 'z'
else:
slice_key = 't'

sdat = {}
twiss_planes = set()
twiss = {}

normal_keys = set()

for k in keys:
sdat[k] = np.empty(n_slice)
if k.startswith('twiss'):
if k == 'twiss' or k == 'twiss_xy':
twiss_planes.add('x')
twiss_planes.add('y')
else:
plane = k[-1] #
assert plane in ('x', 'y')
twiss_planes.add(plane)
else:
normal_keys.add(k)

twiss_plane = ''.join(twiss_planes) # flatten
assert twiss_plane in ('x', 'y', 'xy', 'yx', '')

for i, pg in enumerate(particle_group.split(n_slice, key=slice_key)):
for k in keys:
for k in normal_keys:
sdat[k][i] = pg[k]

# Handle twiss
if twiss_plane:
twiss = pg.twiss(plane=twiss_plane)
for k in twiss:
full_key = f'twiss_{k}'
if full_key not in sdat:
sdat[full_key] = np.empty(n_slice)
sdat[full_key][i] = twiss[k]


return sdat

Expand Down
12 changes: 11 additions & 1 deletion pmd_beamphysics/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,16 @@ def plottable_array(x, nice=True, lim=None):
PARTICLEGROUP_UNITS[f'E{component}'] = unit('V/m')
PARTICLEGROUP_UNITS[f'B{component}'] = unit('T')

# Twiss
for plane in ('x', 'y'):
for k in ('alpha', 'etap'):
PARTICLEGROUP_UNITS[f'twiss_{k}_{plane}'] = unit('1')
for k in ('beta', 'eta', 'emit', 'norm_emit'):
PARTICLEGROUP_UNITS[f'twiss_{k}_{plane}'] = unit('m')
for k in ('gamma', ):
PARTICLEGROUP_UNITS[f'twiss_{k}_{plane}'] = divide_units(unit('1'), unit('m') )





Expand All @@ -479,7 +489,7 @@ def pg_units(key):
for prefix in ['sigma_', 'mean_', 'min_', 'max_', 'ptp_', 'delta_']:
if key.startswith(prefix):
nkey = key[len(prefix):]
return PARTICLEGROUP_UNITS[nkey]
return pg_units(nkey)

if key.startswith('cov_'):
subkeys = key.strip('cov_').split('__')
Expand Down

0 comments on commit 3dab63e

Please sign in to comment.