Skip to content

Commit

Permalink
ENH: 'nice' axis labels + extents
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-lauer committed Sep 13, 2024
1 parent d84a605 commit 4429015
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions pmd_beamphysics/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .metadata import PolarizationDirection, WavefrontMetadata
from . import writers
from .units import known_unit
from .units import known_unit, nice_array

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -411,7 +411,7 @@ def calculate_k0(wavelength: float) -> float:

def conversion_coeffs(wavelength: float, dim: int) -> Tuple[float, ...]:
"""
Conversion coefficients to (eV, radians, radians, ...).
Conversion coefficients to (radians, radians, eV).
Theta-x, theta-y, omega.
"""
Expand Down Expand Up @@ -1256,13 +1256,22 @@ def plot(
if rspace:
data = self.rmesh
labels = get_rspace_labels(self.axis_labels, axis_indices)
domain = [self.rspace_domain[idx] for idx in axis_indices]
units = ["m", "m"]
else:
data = self.kmesh
labels = get_kspace_labels(self.axis_labels, axis_indices)
domain = [self.kspace_domain[idx] for idx in axis_indices]
units = ["rad", "rad"]

(domain_x, domain_y), _scale, unit_prefix = nice_array(np.vstack(domain))
extent = (domain_x[0], domain_x[-1], domain_y[-1], domain_y[0])

if transpose:
data = data.T
labels = tuple(reversed(labels))
extent = tuple(reversed(extent))
units = tuple(reversed(units))

sum_axis = tuple(axis for axis in range(data.ndim) if axis not in axis_indices)
plane_label = " ".join(labels)
Expand All @@ -1286,15 +1295,15 @@ def plot(

def plot(dat, title: str):
ax = remaining_axes.pop(0)
img = ax.imshow(np.mean(dat, axis=sum_axis), cmap=cmap)
img = ax.imshow(np.mean(dat, axis=sum_axis), cmap=cmap, extent=extent)
if xlim is not None:
ax.set_xlim(xlim)
if ylim is not None:
ax.set_ylim(ylim)
if not ax.get_title():
ax.set_title(title)
ax.set_xlabel(f"${labels[0]}$")
ax.set_ylabel(f"${labels[1]}$")
ax.set_xlabel(f"${labels[0]}$ [{unit_prefix}{units[0]}]")
ax.set_ylabel(f"${labels[1]}$ [{unit_prefix}{units[1]}]")
images.append(img)
return img

Expand Down

0 comments on commit 4429015

Please sign in to comment.