diff --git a/pmd_beamphysics/interfaces/genesis.py b/pmd_beamphysics/interfaces/genesis.py index d566fdf..447d1c8 100644 --- a/pmd_beamphysics/interfaces/genesis.py +++ b/pmd_beamphysics/interfaces/genesis.py @@ -398,7 +398,8 @@ def write_genesis4_distribution(particle_group, If particles are at different z, they will be drifted to the same z, because the output should have different times. - If any of the weights are different, the bunch will be resampled. + If any of the weights are different, the bunch will be resampled + to have equal weights. Note that this can be very slow for a large number of particles. """ @@ -423,7 +424,7 @@ def write_genesis4_distribution(particle_group, n = len(P) if verbose: print(f'Resampling {n} weighted particles') - P = P.resample(n) + P = P.resample(n, equal_weights=True) for k in ['x', 'xp', 'y', 'yp', 't']: h5[k] = P[k] diff --git a/pmd_beamphysics/particles.py b/pmd_beamphysics/particles.py index e036a5f..f0e3d66 100644 --- a/pmd_beamphysics/particles.py +++ b/pmd_beamphysics/particles.py @@ -1064,8 +1064,8 @@ def copy(self): return deepcopy(self) @functools.wraps(resample_particles) - def resample(self, n=0): - data = resample_particles(self, n) + def resample(self, n=0, equal_weights=False): + data = resample_particles(self, n, equal_weights=equal_weights) return ParticleGroup(data=data) # Internal sorting diff --git a/pmd_beamphysics/statistics.py b/pmd_beamphysics/statistics.py index fb8bc6f..0be5504 100644 --- a/pmd_beamphysics/statistics.py +++ b/pmd_beamphysics/statistics.py @@ -460,7 +460,7 @@ def slice_statistics(particle_group, keys=['mean_z'], n_slice=40, slice_key=Non -def resample_particles(particle_group, n=0): +def resample_particles(particle_group, n=0, equal_weights=False): """ Resamples a ParticleGroup randomly. @@ -479,6 +479,9 @@ def resample_particles(particle_group, n=0): Number to resample. If n = 0, this will use all particles. + equal_weights: bool, default = False + If True, will ensure that all particles have equal weights. + Returns ------- data: dict of ParticleGroup data @@ -498,19 +501,21 @@ def resample_particles(particle_group, n=0): ixlist = np.random.choice(n_old, n, replace=False) weight = np.full(n, particle_group.charge/n) - # variable weights + # variable weights found + elif equal_weights or n != n_old: + # From SciPy example: + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.rv_discrete.html#scipy.stats.rv_discrete + pk = weight / np.sum(weight) # Probabilities + xk = np.arange(len(pk)) # index + ixsampler = scipy_stats.rv_discrete(name='ixsampler', values=(xk, pk)) + ixlist = ixsampler.rvs(size=n) + weight = np.full(n, particle_group.charge/n) + else: - if n == n_old: - ixlist = np.random.choice(n_old, n, replace=False) - weight = weight[ixlist] #just scramble - else: - # From SciPy example: - # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.rv_discrete.html#scipy.stats.rv_discrete - pk = weight / np.sum(weight) # Probabilities - xk = np.arange(len(pk)) # index - ixsampler = scipy_stats.rv_discrete(name='ixsampler', values=(xk, pk)) - ixlist = ixsampler.rvs(size=n) - weight = np.full(n, particle_group.charge/n) + assert n == n_old + ixlist = np.random.choice(n_old, n, replace=False) + weight = weight[ixlist] #just scramble + data = {} for key in particle_group._settable_array_keys: