Skip to content

Commit

Permalink
Better scaling with plottable_array
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherMayes committed Aug 8, 2023
1 parent ed9bbc9 commit 3969fff
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 218 deletions.
164 changes: 25 additions & 139 deletions docs/examples/plot_examples.ipynb

Large diffs are not rendered by default.

22 changes: 16 additions & 6 deletions pmd_beamphysics/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,8 @@ def plot(self, key1='x', key2=None,
xlim=None,
ylim=None,
return_figure=False,

tex=True, **kwargs):
tex=True, nice=True,
**kwargs):
"""
1d or 2d density plot.
Expand Down Expand Up @@ -855,14 +855,17 @@ def plot(self, key1='x', key2=None,
Number of bins. If None, this will use a heuristic: bins = sqrt(n_particle/4)
xlim: tuple, default = None
Manual setting of the x-axis limits.
Manual setting of the x-axis limits. Note that these are in raw, unscaled units.
ylim: tuple, default = None
Manual setting of the y-axis limits.
Manual setting of the y-axis limits. Note that these are in raw, unscaled units.
tex: bool, defaul = True
tex: bool, default = True
Use TEX for labels
nice: bool, default = True
Scale to nice units
return_figure: bool, default = False
If true, return a matplotlib.figure.Figure object
Expand All @@ -881,13 +884,16 @@ def plot(self, key1='x', key2=None,
fig = density_plot(self, key=key1,
bins=bins,
xlim=xlim,
tex=tex, **kwargs)
tex=tex,
nice=nice,
**kwargs)
else:
fig = marginal_plot(self, key1=key1, key2=key2,
bins=bins,
xlim=xlim,
ylim=ylim,
tex=tex,
nice=nice,
**kwargs)

if return_figure:
Expand All @@ -897,7 +903,9 @@ def slice_plot(self, key='sigma_x',
n_slice=100,
slice_key=None,
tex=True,
nice=True,
return_figure=False,
xlim=None,
ylim=None,
**kwargs):
"""
Expand All @@ -915,6 +923,8 @@ def slice_plot(self, key='sigma_x',
n_slice=n_slice,
slice_key=slice_key,
tex=tex,
nice=nice,
xlim=xlim,
ylim=ylim,
**kwargs)

Expand Down
97 changes: 38 additions & 59 deletions pmd_beamphysics/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from pmd_beamphysics.units import nice_array, nice_scale_prefix
from pmd_beamphysics.units import nice_array, plottable_array, nice_scale_prefix
from pmd_beamphysics.labels import mathlabel


Expand Down Expand Up @@ -41,8 +41,10 @@ def slice_plot(particle_group,
stat_key='sigma_x',
n_slice=40,
slice_key='z',
xlim=None,
ylim=None,
tex=True,
nice=True,
**kwargs):
"""
Complete slice plotting routine. Will plot the density of the slice key on the right axis.
Expand Down Expand Up @@ -85,9 +87,10 @@ def slice_plot(particle_group,
fig, ax = plt.subplots(**kwargs)

# Get nice arrays
x, _, prex = nice_array(slice_dat[x_key])
y, yfactor, prey = nice_array(slice_dat[y_key])
y2, _, prey2 = nice_array(slice_dat[y2_key])
x, f1, prex, xmin, xmax = plottable_array(slice_dat[x_key], nice=nice, lim=xlim)
y, f2, prey, ymin, ymax = plottable_array(slice_dat[y_key], nice=nice, lim=ylim)
# Density on r.h.s
y2, _, prey2, _, _ = plottable_array(slice_dat[y2_key], nice=nice, lim=None)

x_units = f'{prex}{particle_group.units(x_key)}'
y_units = f'{prey}{particle_group.units(y_key)}'
Expand Down Expand Up @@ -115,18 +118,12 @@ def slice_plot(particle_group,
ax2.fill_between(x, 0, y2, color='black', alpha = 0.2)
ax2.set_ylim(0, None)

# Limits
# Actual plot limits, considering scaling
if xlim:
ax.set_xlim( xmin/f1, xmax/f1)
if ylim:
ymin = ylim[0]
ymax = ylim[1]
# Handle None and scaling
if ymin is not None:
ymin = ymin/yfactor
if ymax is not None:
ymax = ymax/yfactor
new_ylim = (ymin, ymax)
ax.set_ylim(new_ylim)

ax.set_ylim( ymin/f2, ymax/f2)

return fig


Expand All @@ -135,7 +132,9 @@ def density_plot(particle_group, key='x',
bins=None,
*,
xlim=None,
tex=True, **kwargs):
tex=True,
nice=True,
**kwargs):
"""
1D density plot. Also see: marginal_plot
Expand All @@ -150,7 +149,7 @@ def density_plot(particle_group, key='x',
bins = int(n/100)

# Scale to nice units and get the factor, unit prefix
x, f1, p1 = nice_array(particle_group[key])
x, f1, p1, xmin, xmax = plottable_array(particle_group[key], nice=nice, lim=xlim)
w = particle_group['weight']
u1 = particle_group.units(key).unitSymbol
ux = p1+u1
Expand All @@ -177,15 +176,7 @@ def density_plot(particle_group, key='x',

# Limits
if xlim:
xmin = xlim[0]
xmax = xlim[1]
# Handle None and scaling
if xmin is not None:
xmin = xmin/f1
if xmax is not None:
xmax = xmax/f1
new_xlim = (xmin, xmax)
ax.set_xlim(new_xlim)
ax.set_xlim(xmin/f1, xmax/f1)

return fig

Expand All @@ -195,6 +186,7 @@ def marginal_plot(particle_group, key1='t', key2='p',
xlim=None,
ylim=None,
tex=True,
nice=True,
**kwargs):
"""
Density plot and projections
Expand Down Expand Up @@ -224,25 +216,30 @@ def marginal_plot(particle_group, key1='t', key2='p',
ylim: tuple, default = None
Manual setting of the y-axis limits.
tex: bool, defaul = True
tex: bool, default = True
Use TEX for labels
nice: bool, default = True
Returns
-------
fig: matplotlib.figure.Figure
"""

"""
if not bins:
n = len(particle_group)
bins = int(np.sqrt(n/4) )

# Scale to nice units and get the factor, unit prefix
x, f1, p1 = nice_array(particle_group[key1])
y, f2, p2 = nice_array(particle_group[key2])

x = particle_group[key1]
y = particle_group[key2]

# Form nice arrays
x, f1, p1, xmin, xmax = plottable_array(x, nice=nice, lim=xlim)
y, f2, p2, ymin, ymax = plottable_array(y, nice=nice, lim=ylim)

w = particle_group['weight']

u1 = particle_group.units(key1).unitSymbol
Expand Down Expand Up @@ -272,8 +269,6 @@ def marginal_plot(particle_group, key1='t', key2='p',
#extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
#ax_joint.imshow(H.T, cmap=cmap, vmin=1e-16, origin='lower', extent=extent, aspect='auto')



# Top histogram
# Old method:
#dx = x.ptp()/bins
Expand Down Expand Up @@ -310,32 +305,16 @@ def marginal_plot(particle_group, key1='t', key2='p',
ax_joint.set_xlabel(labelx)
ax_joint.set_ylabel(labely)

# Limits
# Actual plot limits, considering scaling
if xlim:
xmin = xlim[0]
xmax = xlim[1]
# Handle None and scaling
if xmin is not None:
xmin = xmin/f1
if xmax is not None:
xmax = xmax/f1
new_xlim = (xmin, xmax)
ax_joint.set_xlim(new_xlim)
ax_marg_x.set_xlim(new_xlim)

ax_joint.set_xlim( xmin/f1, xmax/f1)
ax_marg_x.set_xlim(xmin/f1, xmax/f1)

if ylim:
ymin = ylim[0]
ymax = ylim[1]
# Handle None and scaling
if ymin is not None:
ymin = ymin/f2
if ymax is not None:
ymax = ymax/f2
new_ylim = (ymin, ymax)
ax_joint.set_ylim(new_ylim)
ax_marg_y.set_ylim(new_ylim)
ax_joint.set_ylim( ymin/f2, ymax/f2)
ax_marg_y.set_ylim(ymin/f2, ymax/f2)

return fig
return fig


def density_and_slice_plot(particle_group, key1='t', key2='p', stat_keys=['norm_emit_x', 'norm_emit_y'], bins=100, n_slice=30, tex=True):
Expand All @@ -349,8 +328,8 @@ def density_and_slice_plot(particle_group, key1='t', key2='p', stat_keys=['norm_
"""

# Scale to nice units and get the factor, unit prefix
x, f1, p1 = nice_array(particle_group[key1])
y, f2, p2 = nice_array(particle_group[key2])
x, f1, p1, xmin, xmax = plottable_array(particle_group[key1])
y, f2, p2, ymin, ymax = plottable_array(particle_group[key2])
w = particle_group['weight']

u1 = particle_group.units(key1).unitSymbol
Expand Down
71 changes: 57 additions & 14 deletions pmd_beamphysics/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def sqrt_unit(u):

return pmd_unit(unitSymbol=symbol, unitSI=unitSI, unitDimension=dim)

# length mass time current temperature mol luminous
DIMENSION = {
'1' : (0,0,0,0,0,0,0),
# Base units
Expand All @@ -187,14 +187,14 @@ def sqrt_unit(u):
'mol' : (0,0,0,0,0,1,0),
'luminous' : (0,0,0,0,0,0,1),
#
'charge' : (0,0,1,1,0,0,0),
'electric_field' : (1,1,-3,-1,0,0,0),
'charge' : (0,0,1,1,0,0,0),
'electric_field' : (1,1,-3,-1,0,0,0),
'electric_potential' : (1,2,-3,-1,0,0,0),
'magnetic_field' : (0,1,-2,-1,0,0,0),
'velocity' : (1,0,-1,0,0,0,0),
'energy' : (2,1,-2,0,0,0,0),
'momentum' : (1,1,-1,0,0,0,0)
}
'magnetic_field' : (0,1,-2,-1,0,0,0),
'velocity' : (1,0,-1,0,0,0,0),
'energy' : (2,1,-2,0,0,0,0),
'momentum' : (1,1,-1,0,0,0,0)
}
# Inverse
DIMENSION_NAME = {v: k for k, v in DIMENSION.items()}

Expand Down Expand Up @@ -227,7 +227,7 @@ def dimension_name(dim_array):
# Inverse
SI_name = {v: k for k, v in SI_symbol.items()}


# length mass time current temperature mol luminous
known_unit = {
'1' : pmd_unit('', 1, '1'),
'degree' : pmd_unit('degree', np.pi/180, '1'),
Expand All @@ -250,11 +250,14 @@ def dimension_name(dim_array):
'J' : pmd_unit('J', 1, 'energy'),
'eV/c' : pmd_unit('eV/c', e_charge/c_light, 'momentum'),
'eV/m' : pmd_unit('eV/m', e_charge, (1, 1, -2, 0, 0, 0, 0)),
'W/m^2' : pmd_unit('W/m^2', 1, (1, 0, -3, 0, 0, 0, 0)),
'W' : pmd_unit('W', 1, (1, 2, -3, 0, 0, 0, 0)),
'W' : pmd_unit('W', 1, (2, 1, -3, 0, 0, 0, 0)),
'W/m^2' : pmd_unit('W/m^2', 1, (0, 1, -3, 0, 0, 0, 0)),
'W/rad^2' : pmd_unit('W/rad^2', 1, (2, 1, -3, 0, 0, 0, 0)),
'T' : pmd_unit('T', 1, 'magnetic_field')
}



def unit(symbol):
"""
Returns a pmd_unit from a known symbol.
Expand Down Expand Up @@ -367,9 +370,7 @@ def nice_array(a):
Returns:
(array([200., 300.]), 1e-12, 'p')
"""
#print('a', a.tolist())

"""
if np.isscalar(a):
x = a
elif len(a) == 1:
Expand All @@ -383,6 +384,48 @@ def nice_array(a):
return a/fac, fac, prefix


def plottable_array(x, nice=True, lim=None):
"""
Similar to nice_array, but also considers limits for plotting
Parameters
----------
x: array-like
nice: bool, default = True
Scale array by some nice factor.
xlim: tuple, default = None
Returns
-------
scaled_array: np.ndarray
factor: float
prefix: str
xmin: float
xmax : float
"""
if lim is not None:
if lim[0] is None:
xmin = x.min()
else:
xmin = lim[0]
if lim[1] is None:
xmax = x.max()
else:
xmax = lim[1]

else:
xmin = x.min()
xmax = x.max()

if nice:
_, factor, p1 = nice_array([xmin, xmax])
else:
factor, p1 = 1, ''

return x/factor, factor, p1, xmin, xmax




# -------------------------
Expand Down

0 comments on commit 3969fff

Please sign in to comment.