Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 54 additions & 8 deletions src/libraries/JANA/Components/JHasInputs.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@ struct JHasInputs {

std::vector<InputBase*> m_inputs;
std::vector<VariadicInputBase*> m_variadic_inputs;
std::vector<std::pair<InputBase*, VariadicInputBase*>> m_ordered_inputs;

void RegisterInput(InputBase* input) {
m_inputs.push_back(input);
m_ordered_inputs.push_back({input, nullptr});
}

void RegisterInput(VariadicInputBase* input) {
m_variadic_inputs.push_back(input);
m_ordered_inputs.push_back({nullptr, input});
}

struct InputOptions {
Expand Down Expand Up @@ -146,14 +149,12 @@ struct JHasInputs {
class Input : public InputBase {

std::vector<const T*> m_data;
std::string m_tag;

public:

Input(JHasInputs* owner) {
owner->RegisterInput(this);
m_type_name = JTypeInfo::demangle<T>();
m_databundle_name = m_type_name;
m_level = JEventLevel::None;
}

Expand All @@ -164,8 +165,7 @@ struct JHasInputs {
}

void SetTag(std::string tag) {
m_tag = tag;
m_databundle_name = m_type_name + ":" + tag;
m_databundle_name = tag;
}

const std::vector<const T*>& operator()() { return m_data; }
Expand All @@ -180,23 +180,23 @@ struct JHasInputs {
auto& level = m_level;
m_data.clear();
if (level == event.GetLevel() || level == JEventLevel::None) {
event.Get<T>(m_data, m_tag, !m_is_optional);
event.Get<T>(m_data, m_databundle_name, !m_is_optional);
}
else {
if (m_is_optional && !event.HasParent(level)) return;
event.GetParent(level).template Get<T>(m_data, m_tag, !m_is_optional);
event.GetParent(level).template Get<T>(m_data, m_databundle_name, !m_is_optional);
}
}
void PrefetchCollection(const JEvent& event) {
if (m_level == event.GetLevel() || m_level == JEventLevel::None) {
auto fac = event.GetFactory<T>(m_tag, !m_is_optional);
auto fac = event.GetFactory<T>(m_databundle_name, !m_is_optional);
if (fac != nullptr) {
fac->Create(event);
}
}
else {
if (m_is_optional && !event.HasParent(m_level)) return;
auto fac = event.GetParent(m_level).template GetFactory<T>(m_tag, !m_is_optional);
auto fac = event.GetParent(m_level).template GetFactory<T>(m_databundle_name, !m_is_optional);
if (fac != nullptr) {
fac->Create(event);
}
Expand Down Expand Up @@ -467,6 +467,11 @@ struct JHasInputs {
const std::vector<JEventLevel>& variadic_input_levels,
const std::vector<std::vector<std::string>>& variadic_input_databundle_names) {

if (m_variadic_inputs.size() == 1 && variadic_input_databundle_names.size() == 0) {
WireInputsCompatibility(component_level, single_input_levels, single_input_databundle_names);
return;
}

// Validate that we have the correct number of input databundle names
if (single_input_databundle_names.size() != m_inputs.size()) {
throw JException("Wrong number of (nonvariadic) input databundle names! Expected %d, found %d", m_inputs.size(), single_input_databundle_names.size());
Expand Down Expand Up @@ -501,6 +506,47 @@ struct JHasInputs {
}
}

void WireInputsCompatibility(JEventLevel component_level,
const std::vector<JEventLevel>& single_input_levels,
const std::vector<std::string>& single_input_databundle_names) {

// Figure out how many collection names belong to the variadic input
int variadic_databundle_count = single_input_databundle_names.size() - m_inputs.size();
int databundle_name_index = 0;
int databundle_level_index = 0;

for (auto& pair : m_ordered_inputs) {
auto* single_input = pair.first;
auto* variadic_input = pair.second;
if (single_input != nullptr) {
single_input->SetDatabundleName(single_input_databundle_names.at(databundle_name_index));
if (single_input_levels.empty()) {
single_input->SetLevel(component_level);
}
else {
single_input->SetLevel(single_input_levels.at(databundle_level_index));
}
databundle_name_index += 1;
databundle_level_index += 1;
}
else {
std::vector<std::string> variadic_databundle_names;
for (int i=0; i<variadic_databundle_count; ++i) {
variadic_databundle_names.push_back(single_input_databundle_names.at(databundle_name_index+i));
}
variadic_input->SetRequestedDatabundleNames(variadic_databundle_names);
if (single_input_levels.empty()) {
variadic_input->SetLevel(component_level);
}
else {
variadic_input->SetLevel(single_input_levels.at(databundle_level_index)); // Last one wins!
}
databundle_name_index += variadic_databundle_count;
databundle_level_index += 1;
}
}
}

void SummarizeInputs(JComponentSummary::Component& summary) const {
for (const auto* input : m_inputs) {
summary.AddInput(new JComponentSummary::Collection("", input->GetDatabundleName(), input->GetTypeName(), input->GetLevel()));
Expand Down
49 changes: 41 additions & 8 deletions src/libraries/JANA/Components/JHasOutputs.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ struct JHasOutputs {
protected:

void CreateHelperFactory(JMultifactory& fac) override {
fac.DeclareOutput<T>(this->collection_names[0]);
fac.DeclareOutput<T>(this->collection_names[0], !is_not_owner);
}

void SetCollection(JMultifactory& fac) override {
Expand All @@ -72,7 +72,9 @@ struct JHasOutputs {
auto fac = event.Insert(m_data, this->collection_names[0]);
fac->SetNotOwnerFlag(is_not_owner);
}
void Reset() override { }
void Reset() override {
m_data.clear();
}

};

Expand Down Expand Up @@ -174,7 +176,7 @@ struct JHasOutputs {

void Reset() override {
m_data.clear();
for (auto& coll_name : this->collection_names) {
for (size_t i=0; i<collection_names.size(); ++i) {
m_data.push_back(std::make_unique<typename PodioT::collection_type>());
}
}
Expand All @@ -186,14 +188,45 @@ struct JHasOutputs {
size_t single_output_index = 0;
size_t variadic_output_index = 0;

size_t variadic_output_count = 0;
for (auto* output : m_outputs) {
output->collection_names.clear();
output->level = component_level;
if (output->is_variadic) {
output->collection_names = variadic_output_databundle_names.at(variadic_output_index++);
variadic_output_count += 1;
}
else {
output->collection_names.push_back(single_output_databundle_names.at(single_output_index++));
}
if (variadic_output_count == 1 && variadic_output_databundle_names.size() == 0) {
// Obtain variadic databundle names from excess single-output databundle names
int variadic_databundle_count = single_output_databundle_names.size() - m_outputs.size() + 1;
int current_databundle_index = 0;

for (auto* output : m_outputs) {
output->collection_names.clear();
output->level = component_level;
if (output->is_variadic) {
std::vector<std::string> variadic_names;
for (int i=0; i<variadic_databundle_count; ++i) {
variadic_names.push_back(single_output_databundle_names.at(current_databundle_index+i));
}
output->collection_names = variadic_names;
current_databundle_index += variadic_databundle_count;
}
else {
output->collection_names.push_back(single_output_databundle_names.at(current_databundle_index));
current_databundle_index += 1;
}
}
}
else {
// Do the obvious, sensible thing instead
for (auto* output : m_outputs) {
output->collection_names.clear();
output->level = component_level;
if (output->is_variadic) {
output->collection_names = variadic_output_databundle_names.at(variadic_output_index++);
}
else {
output->collection_names.push_back(single_output_databundle_names.at(single_output_index++));
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/programs/unit_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ set(TEST_SOURCES
Components/JFactoryGeneratorTests.cc
Components/JHasInputsTests.cc
Components/JMultiFactoryTests.cc
Components/JOmniFactoryTests.cc
Components/JServiceTests.cc
Components/UnfoldTests.cc
Components/UserExceptionTests.cc
Expand Down
152 changes: 152 additions & 0 deletions src/programs/unit_tests/Components/JOmniFactoryTests.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@

#include <catch.hpp>
#include <JANA/JVersion.h>
#include <JANA/JApplicationFwd.h>
#include <JANA/Components/JOmniFactory.h>
#include <JANA/Components/JOmniFactoryGeneratorT.h>

struct MyHit {int e; };

#if JANA2_HAVE_PODIO

#include <PodioDatamodel/ExampleHitCollection.h>

struct MyFac : public JOmniFactory<MyFac> {
Input<MyHit> hits_in {this};
VariadicPodioInput<ExampleHit> variadic_podio_hits_in {this};
PodioInput<ExampleHit> podio_hits_in {this};

Output<MyHit> hits_out {this};
VariadicPodioOutput<ExampleHit> variadic_podio_hits_out {this};
PodioOutput<ExampleHit> podio_hits_out {this};

MyFac() {
hits_out.SetNotOwnerFlag(true);
}
void Configure() { }
void ChangeRun(int32_t /*run_nr*/) { }
void Execute(int32_t /*run_nr*/, uint64_t /*evt_nr*/) {

REQUIRE(hits_out().size() == 0);
REQUIRE(podio_hits_out()->size() == 0);
REQUIRE(variadic_podio_hits_out().size() == 2);
REQUIRE(variadic_podio_hits_out().at(0)->size() == 0);
REQUIRE(variadic_podio_hits_out().at(1)->size() == 0);

for (auto hit : *hits_in) {
hits_out().push_back(const_cast<MyHit*>(hit));
}

podio_hits_out()->setSubsetCollection();
variadic_podio_hits_out().at(0)->setSubsetCollection();

for (auto hit : *podio_hits_in) {
podio_hits_out()->push_back(hit);
variadic_podio_hits_out().at(0)->push_back(hit);
}

variadic_podio_hits_out().at(1)->push_back(MutableExampleHit(22, 1.1, 1.1, 1.1, 10, 0));
}
};

void test_single_event(JEvent& event) {
event.Insert(new MyHit{99}, "lw");
ExampleHitCollection coll;
coll.push_back(MutableExampleHit{14,0.0,0.0,0.0,100,0});
coll.push_back(MutableExampleHit{21,0.0,0.0,0.0,100,0});
event.InsertCollection<ExampleHit>(std::move(coll), "podio");

ExampleHitCollection coll2;
coll2.push_back(MutableExampleHit{30,0.0,0.0,0.0,100,0});
event.InsertCollection<ExampleHit>(std::move(coll2), "v_podio_0");

ExampleHitCollection coll3;
coll3.push_back(MutableExampleHit{10101,0.0,0.0,0.0,100,0});
event.InsertCollection<ExampleHit>(std::move(coll3), "v_podio_1");

auto hits = event.Get<MyHit>("lw2");
REQUIRE(hits.size() == 1);
REQUIRE(hits.at(0)->e == 99);

auto podio_hits = event.GetCollection<ExampleHit>("podio2");
REQUIRE(podio_hits->size() == 2);
REQUIRE(podio_hits->at(0).cellID() == 14);
REQUIRE(podio_hits->at(1).cellID() == 21);

auto podio_hits_v2 = event.GetCollection<ExampleHit>("v_podio_2");
REQUIRE(podio_hits_v2 ->size() == 2);
REQUIRE(podio_hits_v2->at(0).cellID() == 14);
REQUIRE(podio_hits_v2->at(1).cellID() == 21);

auto podio_hits_v3 = event.GetCollection<ExampleHit>("v_podio_3");
REQUIRE(podio_hits_v3 ->size() == 1);
REQUIRE(podio_hits_v3->at(0).cellID() == 22);
}

TEST_CASE("JOmniFactoryTests_VariadicWiring") {
JApplication app;
app.Add(new JOmniFactoryGeneratorT<MyFac>(
"sut",
{"lw", "v_podio_0", "v_podio_1", "podio"},
{"lw2", "v_podio_2", "v_podio_3", "podio2"}));

auto event = std::make_shared<JEvent>(&app);
test_single_event(*event);
event->Clear();
test_single_event(*event);
}

#endif


struct MyFac2 : public JOmniFactory<MyFac2> {
Output<MyHit> hits_out {this};

void Configure() { }
void ChangeRun(int32_t /*run_nr*/) { }
void Execute(int32_t /*run_nr*/, uint64_t /*evt_nr*/) {
REQUIRE(hits_out().size() == 0);
hits_out().push_back(new MyHit{22});
}
};


TEST_CASE("JOmniFactoryTests_OutputsCleared") {
JApplication app;
app.Add(new JOmniFactoryGeneratorT<MyFac2>(
"sut",
{},
{"huegelgrab"}));

auto event = std::make_shared<JEvent>(&app);
REQUIRE(event->GetSingleStrict<MyHit>("huegelgrab")->e == 22);

event->Clear();
REQUIRE(event->GetSingleStrict<MyHit>("huegelgrab")->e == 22);
}

struct MyFac3 : public JOmniFactory<MyFac3> {
Input<MyHit> hits_in {this};
Output<MyHit> hits_out {this};

void Configure() { }
void ChangeRun(int32_t /*run_nr*/) { }
void Execute(int32_t /*run_nr*/, uint64_t /*evt_nr*/) {
REQUIRE(hits_in().at(0)->e == 123);
hits_out().push_back(new MyHit{1234});
}
};


TEST_CASE("JOmniFactoryTests_LightweightInputTag") {
JApplication app;
app.Add(new JOmniFactoryGeneratorT<MyFac3>(
"sut",
{"huegelgrab"},
{"schlafen"}));

auto event = std::make_shared<JEvent>(&app);
event->Insert<MyHit>(new MyHit{123}, "huegelgrab");
REQUIRE(event->GetSingleStrict<MyHit>("schlafen")->e == 1234);

}
Loading