diff --git a/pmd_beamphysics/particles.py b/pmd_beamphysics/particles.py index f0e3d66..73ac345 100644 --- a/pmd_beamphysics/particles.py +++ b/pmd_beamphysics/particles.py @@ -1,4 +1,4 @@ -from pmd_beamphysics.units import dimension, dimension_name, SI_symbol, pg_units, c_light, parse_bunching_str +from pmd_beamphysics.units import pg_units, c_light, parse_bunching_str from pmd_beamphysics.interfaces.astra import write_astra import pmd_beamphysics.interfaces.bmad as bmad @@ -678,7 +678,7 @@ def average_current(self): return self.charge / dt def bunching(self, wavelength): - """ + r""" Calculate the normalized bunching parameter, which is the magnitude of the complex sum of weighted exponentials at a given point. @@ -910,7 +910,8 @@ def write(self, h5, name=None): if isinstance(h5, str): fname = os.path.expandvars(h5) g = File(fname, 'w') - pmd_init(g, basePath='/', particlesPath='.' ) + pmd_init(g, basePath='/', particlesPath='particles' ) + g = g.create_group('particles') else: g = h5 @@ -1158,6 +1159,14 @@ def load_bunch_data(h5): """ Load particles into structured numpy array. """ + + # Legacy-style particles with no species + if 'position' not in h5: + species = list(h5) + if len(species) != 1: + raise NotImplementedError(f"multiple species in particle paths: {species}") + h5 = h5[species[0]] + n = len(h5['position/x']) attrs = dict(h5.attrs) @@ -1257,9 +1266,10 @@ def split_particles(particle_group, n_chunks = 100, key='z'): # Split particles into chunks plist = [] for chunk in np.array_split(iz, n_chunks): + # Prepare data data = {} - #keys = ['x', 'px', 'y', 'py', 'z', 'pz', 't', 'status', 'weight'] + for k in particle_group._settable_array_keys: data[k] = getattr(particle_group, k)[chunk] # These should be scalars diff --git a/pmd_beamphysics/writers.py b/pmd_beamphysics/writers.py index 5caaab6..0c2ff73 100644 --- a/pmd_beamphysics/writers.py +++ b/pmd_beamphysics/writers.py @@ -57,9 +57,13 @@ def write_pmd_bunch(h5, data, name=None): g = h5.create_group(name) else: g = h5 + + # Write into species group + species = data['species'] + g = g.create_group(species) # Attributes - g.attrs['speciesType'] = fstr( data['species'] ) + g.attrs['speciesType'] = fstr( species ) g.attrs['numParticles'] = data['n_particle'] g.attrs['totalCharge'] = data['charge'] g.attrs['chargeUnitSI'] = 1.0