Skip to content

Commit

Permalink
Handle legacy gridDataOrder, use axisLabels instead in FieldMesh
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherMayes committed May 30, 2024
1 parent 383f823 commit 4057fbb
Show file tree
Hide file tree
Showing 10 changed files with 481 additions and 808 deletions.
Binary file modified docs/examples/data/rfgun.h5
Binary file not shown.
Binary file modified docs/examples/data/rfgun_rectangular.h5
Binary file not shown.
Binary file modified docs/examples/data/solenoid.h5
Binary file not shown.
372 changes: 101 additions & 271 deletions docs/examples/fields/field_examples.ipynb

Large diffs are not rendered by default.

379 changes: 93 additions & 286 deletions docs/examples/fields/field_expansion.ipynb

Large diffs are not rendered by default.

454 changes: 232 additions & 222 deletions docs/examples/fields/field_tracking.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pmd_beamphysics/fields/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# Analysis

def accelerating_voltage_and_phase(z, Ez, frequency):
"""
r"""
Computes the accelerating voltage and phase for a v=c positively charged particle in an accelerating cavity field.
Z = \int Ez * e^{-i k z} dz
Expand Down Expand Up @@ -57,7 +57,7 @@ def track_field_1d(z,
debug=False,
max_step=None,
):
"""
r"""
Tracks a particle in a 1d complex electric field Ez, oscillating as Ez * exp(-i omega t)
Uses scipy.integrate.solve_ivp to track the particle.
Expand Down Expand Up @@ -182,7 +182,7 @@ def track_field_1df(Ez_f,
max_step=None,
method='RK23'
):
"""
r"""
Similar to track_field_1d, execpt uses a function Ez_f
Tracks a particle in a 1d electric field Ez(z, t)
Expand Down Expand Up @@ -342,7 +342,7 @@ def autophase_field(field_mesh, pz0=0, scale=1, species='electron', tol=1e-9, ve

# Function for use in brent
def phase_f(phase_deg):
zf, pf = track_field_1d(z,
_, pf, _ = track_field_1d(z,
Ez,
frequency=frequency,
z0=zmin,
Expand Down
27 changes: 20 additions & 7 deletions pmd_beamphysics/fields/fieldmesh.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pmd_beamphysics.units import pg_units

from pmd_beamphysics.readers import component_data, expected_record_unit_dimension, field_record_components, field_paths, component_from_alias, load_field_attrs, component_alias
from pmd_beamphysics.readers import component_data, expected_record_unit_dimension, field_record_components, field_paths, component_from_alias, load_field_attrs, component_alias, is_legacy_fortran_data_ordering, decode_attr

from pmd_beamphysics.writers import write_pmd_field, pmd_field_init

Expand Down Expand Up @@ -796,7 +796,7 @@ def load_field_data_h5(h5, verbose=True):
data dict
"""
data = {'components':{}}

# Load attributes
attrs, other = load_field_attrs(h5.attrs, verbose=verbose)
attrs.update(other)
Expand All @@ -806,15 +806,28 @@ def load_field_data_h5(h5, verbose=True):
for g, comps in field_record_components.items():
if g not in h5:
continue


# Extract axis labels for data transposing
axis_labels = tuple([decode_attr(a) for a in h5.attrs['axisLabels']])

# Get the full openPMD unitDimension
required_dim = expected_record_unit_dimension[g]


# Filter out only the components that we have
comps = [comp for comp in comps if comp in h5[g]]

for comp in comps:
if comp not in h5[g]:
continue
name = g+'/'+comp
cdat = component_data(h5[name])

# Handle legacy format
if is_legacy_fortran_data_ordering(h5[name].attrs):
if 'r' in comps:
axis_labels = ('z', 'theta', 'r')
else:
axis_labels = ('z', 'y', 'x')


cdat = component_data(h5[name], axis_labels=axis_labels)

# Check dimensions
dim = h5[name].attrs['unitDimension']
Expand Down
39 changes: 25 additions & 14 deletions pmd_beamphysics/readers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .units import dimension, dimension_name, SI_symbol, c_light, e_charge
from .tools import decode_attrs, decode_attr


import numpy as np

import warnings

#-----------------------------------------
# General Utilities

Expand Down Expand Up @@ -154,18 +155,27 @@ def component_unit_dimension(h5):
Return the unit dimension tuple
"""
return tuple(h5.attrs['unitDimension'])


def is_legacy_fortran_data_ordering(component_data_attrs):
if 'gridDataOrder' in component_data_attrs:
warnings.warn("Legacy gridDataOrder in component. Please remove and use axisLabels at the group level.")
if decode_attr(component_data_attrs['gridDataOrder'])=='F':
return True
return False

def component_data(h5, slice = slice(None), unit_factor=1):
def component_data(h5, slice = slice(None), unit_factor=1, axis_labels=None):
"""
Returns a numpy array from an h5 component.
Determines wheter a component has constant data, or array data, and returns that.
An optional slice allows parts of the array to be retrieved.
This checks for a gridDataOrder attribute: F or C. If F, the np array is transposed.
This checks for legacy gridDataOrder attribute: F or C. If F, the np array is transposed.
Unit factor is an additional factor to convert from SI units to output units.
"""

Expand All @@ -182,22 +192,24 @@ def component_data(h5, slice = slice(None), unit_factor=1):
if is_constant_component(h5):
dat = np.full(h5.attrs['shape'], h5.attrs['value'])[slice]

# Check multidimensional for data ordering
elif len(h5.shape) > 1:

# Check for Fortran order
if 'gridDataOrder' in h5.attrs and decode_attr(h5.attrs['gridDataOrder'])=='F':

# Check multidimensional for data ordering, convert to 'x', 'y', 'z' order
elif len(h5.shape) > 1:
if axis_labels is None:
raise ValueError('axis_labels required for multidimensional arrays')

# Reorder to x, y, z
if axis_labels in [('z', 'y', 'x'), ('z', 'theta', 'r')]:
if isinstance(slice, tuple):
# Need to transpose the slice ordering
slice = slice[::-1]

# Retrieve dataset and transpose for C order
dat = h5[slice]
dat = np.transpose(dat)
else:
# C-order
elif axis_labels in [('x', 'y', 'z'), ('r', 'theta', 'z')]:
dat = h5[slice]
else:
raise NotImplementedError(f'axis_labels: {axis_labels}')

# 1-D array
else:
Expand Down Expand Up @@ -255,8 +267,7 @@ def particle_array(h5, component, slice=slice(None), include_offset=True):
if include_offset and ocomponent in h5 :
offset = component_data(h5[ocomponent], slice = slice, unit_factor=unit_factor)
dat += offset



return dat


Expand Down Expand Up @@ -330,7 +341,7 @@ def component_str(particle_group, name):
'eleAnchorPt', 'gridGeometry', 'axisLabels',
# reals and ints
'gridLowerBound', 'gridOriginOffset', 'gridSpacing', 'gridSize', 'harmonic'
]
]

# Dict with options
optional_field_attrs = {
Expand Down
10 changes: 6 additions & 4 deletions pmd_beamphysics/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@ def write_pmd_field(h5, data, name=None):
# Validate attrs
attrs, other = load_field_attrs(data['attrs'])

# Encode and write required and optional
# Encode for writing
attrs = encode_attrs(attrs)

# Write attributes
for k, v in attrs.items():
g.attrs[k] = v

Expand All @@ -118,7 +120,9 @@ def write_pmd_field(h5, data, name=None):
val = val.astype(complex)

# Write
g2 = write_component_data(g, key, val, unit=u)
g2 = write_component_data(g, key, val, unit=u)




def write_component_data(h5, name, data, unit=None):
Expand All @@ -139,8 +143,6 @@ def write_component_data(h5, name, data, unit=None):
else:
h5[name] = data
g = h5[name]
if len(data.shape) > 1:
g.attrs['gridDataOrder'] = fstr('C') # C order for numpy/h5py

if unit:
g.attrs['unitSI'] = unit.unitSI
Expand Down

0 comments on commit 4057fbb

Please sign in to comment.