Skip to content

Commit

Permalink
Fix crash in multithreaded celer-g4 (#1627)
Browse files Browse the repository at this point in the history
  • Loading branch information
amandalund authored Feb 17, 2025
1 parent 5a44cf5 commit bf95fc8
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 31 deletions.
1 change: 1 addition & 0 deletions app/celer-g4/DetectorConstruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ auto DetectorConstruction::construct_field() const -> FieldData
}

CELER_VALIDATE(false, << "invalid field type '" << field_type << "'");
CELER_ASSERT_UNREACHABLE();
}

//---------------------------------------------------------------------------//
Expand Down
32 changes: 11 additions & 21 deletions app/celer-g4/PGPrimaryGeneratorAction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "celeritas/ext/GeantImporter.hh"
#include "celeritas/ext/GeantUnits.hh"
#include "celeritas/inp/Events.hh"
#include "celeritas/phys/ParticleParams.hh"
#include "celeritas/phys/Primary.hh"

namespace celeritas
Expand All @@ -26,27 +25,26 @@ namespace app
namespace
{
//---------------------------------------------------------------------------//
auto make_particles(std::vector<PDGNumber> const& all_pdg)
std::vector<ParticleId> make_particle_ids(std::vector<PDGNumber> const& all_pdg)
{
CELER_VALIDATE(!all_pdg.empty(),
<< "primary generator has no input particles");

auto* par_table = G4ParticleTable::GetParticleTable();
CELER_ASSERT(par_table);

// Find and convert Geant4 particles
ParticleParams::Input inp;
for (auto pdg : all_pdg)
// Find and Geant4 particles and create fake Particle IDs
std::vector<ParticleId> result(all_pdg.size());
for (auto i : range(all_pdg.size()))
{
auto const& pdg = all_pdg[i];
CELER_EXPECT(pdg);
auto* p = par_table->FindParticle(pdg.get());
CELER_VALIDATE(
p, << "particle with PDG " << pdg.get() << " is not loaded");
inp.push_back(
ParticleParams::ParticleInput::from_import(import_particle(*p)));
result[i] = ParticleId(i);
}

return std::make_shared<ParticleParams>(std::move(inp));
return result;
}

//---------------------------------------------------------------------------//
Expand All @@ -57,19 +55,10 @@ auto make_particles(std::vector<PDGNumber> const& all_pdg)
* Construct primary action.
*/
PGPrimaryGeneratorAction::PGPrimaryGeneratorAction(Input const& i)
: particle_params_{make_particles(i.pdg)}
, generate_primaries_{i, *particle_params_}
: pdg_(i.pdg), generate_primaries_(i, make_particle_ids(i.pdg))
{
// Generate one particle at each call to \c GeneratePrimaryVertex()
gun_.SetNumberOfParticles(1);

// Save the particle definitions corresponding to particle IDs
particle_def_.reserve(i.pdg.size());
for (auto const& pdg : i.pdg)
{
particle_def_.push_back(
G4ParticleTable::GetParticleTable()->FindParticle(pdg.get()));
}
}

//---------------------------------------------------------------------------//
Expand All @@ -92,9 +81,10 @@ void PGPrimaryGeneratorAction::GeneratePrimaries(G4Event* event)

for (auto const& p : primaries)
{
CELER_ASSERT(p.particle_id < particle_def_.size());
CELER_ASSERT(p.particle_id < pdg_.size());
gun_.SetParticleDefinition(
particle_def_[p.particle_id.unchecked_get()]);
G4ParticleTable::GetParticleTable()->FindParticle(
pdg_[p.particle_id.unchecked_get()].get()));
gun_.SetParticlePosition(convert_to_geant(p.position, clhep_length));
gun_.SetParticleMomentumDirection(convert_to_geant(p.direction, 1));
gun_.SetParticleEnergy(convert_to_geant(p.energy, CLHEP::MeV));
Expand Down
3 changes: 1 addition & 2 deletions app/celer-g4/PGPrimaryGeneratorAction.hh
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,9 @@ class PGPrimaryGeneratorAction final : public G4VUserPrimaryGeneratorAction
private:
using GeneratorImpl = celeritas::PrimaryGenerator;

GeneratorImpl::SPConstParticles particle_params_;
std::vector<PDGNumber> pdg_;
GeneratorImpl generate_primaries_;
G4ParticleGun gun_;
std::vector<G4ParticleDefinition*> particle_def_;
};

//---------------------------------------------------------------------------//
Expand Down
38 changes: 30 additions & 8 deletions src/celeritas/phys/PrimaryGenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//---------------------------------------------------------------------------//
#include "PrimaryGenerator.hh"

#include <random>
#include <utility>

#include "corecel/cont/Range.hh"
#include "corecel/cont/VariantUtils.hh"
Expand Down Expand Up @@ -73,6 +73,22 @@ auto make_direction_sampler(inp::AngleDistribution const& i)
i);
}

//---------------------------------------------------------------------------//
/*!
* Get a vector of particle IDs from PDG number.
*/
std::vector<ParticleId> make_particle_ids(std::vector<PDGNumber> const& pdgs,
ParticleParams const& particles)
{
std::vector<ParticleId> result;
result.reserve(pdgs.size());
for (auto const& pdg : pdgs)
{
result.push_back(particles.find(pdg));
}
return result;
}

//---------------------------------------------------------------------------//
} // namespace

Expand All @@ -95,30 +111,36 @@ PrimaryGenerator::from_options(SPConstParticles particles,

//---------------------------------------------------------------------------//
/*!
* Construct with options and shared particle data.
* Construct with options and particle IDs.
*/
PrimaryGenerator::PrimaryGenerator(Input const& i,
ParticleParams const& particles)
std::vector<ParticleId> particle_id)
: num_events_{i.num_events}
, primaries_per_event_{i.primaries_per_event}
, seed_{i.seed}
, sample_energy_{make_energy_sampler(i.energy)}
, sample_pos_{make_position_sampler(i.shape)}
, sample_dir_{make_direction_sampler(i.angle)}
, particle_id_(std::move(particle_id))
{
// TODO: seed based on event
this->seed(UniqueEventId{0});

particle_id_.reserve(i.pdg.size());
for (auto const& pdg : i.pdg)
{
particle_id_.push_back(particles.find(pdg));
}
CELER_VALIDATE(
std::all_of(particle_id_.begin(), particle_id_.end(), LogicalTrue{}),
<< R"(invalid or missing particle types specified for primary generator)");
}

//---------------------------------------------------------------------------//
/*!
* Construct with options and shared particle data.
*/
PrimaryGenerator::PrimaryGenerator(Input const& i,
ParticleParams const& particles)
: PrimaryGenerator(i, make_particle_ids(i.pdg, particles))
{
}

//---------------------------------------------------------------------------//
/*!
* Generate primary particles from a single event.
Expand Down
3 changes: 3 additions & 0 deletions src/celeritas/phys/PrimaryGenerator.hh
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ class PrimaryGenerator : public EventReaderInterface
// Construct from shared particle data and new input
PrimaryGenerator(Input const&, ParticleParams const& particles);

// Construct from particle IDs and new input
PrimaryGenerator(Input const&, std::vector<ParticleId> particle_ids);

//! Prevent copying and moving
CELER_DELETE_COPY_MOVE(PrimaryGenerator);
~PrimaryGenerator() override = default;
Expand Down
3 changes: 3 additions & 0 deletions src/celeritas/phys/PrimaryGeneratorOptions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ inp::EnergyDistribution inp_from_energy(DistributionOptions const& options)
<< to_cstring(options.distribution) << "' for "
<< sampler_name);
}
CELER_ASSERT_UNREACHABLE();
}

//---------------------------------------------------------------------------//
Expand All @@ -84,6 +85,7 @@ inp::ShapeDistribution inp_from_position(DistributionOptions const& options)
<< to_cstring(options.distribution) << "' for "
<< sampler_name);
}
CELER_ASSERT_UNREACHABLE();
}

//---------------------------------------------------------------------------//
Expand All @@ -105,6 +107,7 @@ inp::AngleDistribution inp_from_direction(DistributionOptions const& options)
<< to_cstring(options.distribution) << "' for "
<< sampler_name);
}
CELER_ASSERT_UNREACHABLE();
}

//---------------------------------------------------------------------------//
Expand Down
1 change: 1 addition & 0 deletions src/orange/detail/OrangeInputIOImpl.json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ VariantTransform import_transform(nlohmann::json const& src)
<< "invalid number of elements in transform: "
<< data.size());
}
CELER_ASSERT_UNREACHABLE();
}

//---------------------------------------------------------------------------//
Expand Down

0 comments on commit bf95fc8

Please sign in to comment.