Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion phlex/core/declared_output.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ namespace phlex::experimental {
tbb::flow::graph& g,
detail::output_function_t&& ft) :
consumer{std::move(name), std::move(predicates)},
node_{g, concurrency, [f = std::move(ft)](message const& msg) -> tbb::flow::continue_msg {
node_{g, concurrency, [this, f = std::move(ft)](message const& msg) -> tbb::flow::continue_msg {
if (not msg.store->is_flush()) {
f(*msg.store);
++calls_;
}
return {};
}}
Expand Down
2 changes: 2 additions & 0 deletions phlex/core/declared_output.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ namespace phlex::experimental {
detail::output_function_t&& ft);

tbb::flow::receiver<message>& port() noexcept;
std::size_t num_calls() const { return calls_; }

private:
tbb::flow::function_node<message> node_;
std::atomic<std::size_t> calls_;
};

using declared_output_ptr = std::unique_ptr<declared_output>;
Expand Down
4 changes: 3 additions & 1 deletion phlex/core/edge_maker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ namespace phlex::experimental {

// Create edges to outputs
for (auto const& [output_name, output_node] : outputs) {
make_edge(source, output_node->port());
for (auto& [_, provider] : providers) {
make_edge(provider->sender(), output_node->port());
}
for (auto const& named_port : producers_.values()) {
if (named_port.to_output == nullptr) {
throw std::runtime_error("Unexpected null output port for " + named_port.node.full());
Expand Down
3 changes: 3 additions & 0 deletions phlex/core/node_catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ namespace phlex::experimental {
if (auto node = providers.get(node_name)) {
return node->num_calls();
}
if (auto node = outputs.get(node_name)) {
return node->num_calls();
}
throw std::runtime_error("Unknown node type with name: "s + node_name);
}
}
6 changes: 6 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ cet_test(
Boost::json
phlex::core
)
cet_test(
output_products
USE_CATCH2_MAIN
SOURCE output_products.cpp
LIBRARIES layer_generator phlex::core spdlog::spdlog
)
cet_test(
data_cell_counting
USE_CATCH2_MAIN
Expand Down
5 changes: 4 additions & 1 deletion test/form/form_test.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
},
form_output: {
cpp: 'form_module',
products: ['sum'],
// FIXME: Should make it possible to *not* write products created by nodes.
// If 'i' and 'j' are omitted from the products sequence below, an error
// is encountered with the message: 'No configuration found for product: j'.
products: ['sum', 'i', 'j'],
},
},
}
67 changes: 67 additions & 0 deletions test/output_products.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// =======================================================================================
// This is a simple test to ensure that data products are "written" or "output" to an
// output node.
//
// N.B. Output nodes will eventually be replaced with preserver nodes.
// =======================================================================================

#include "phlex/core/framework_graph.hpp"
#include "phlex/model/data_cell_index.hpp"
#include "plugins/layer_generator.hpp"

#include "catch2/catch_test_macros.hpp"

#include <ranges>
#include <set>
#include <string>

using namespace phlex;

namespace {
class product_recorder {
public:
explicit product_recorder(std::set<std::string>& products) : products_{&products} {}

void record(experimental::product_store const& store)
{
for (auto const& product_name : store | std::views::keys) {
products_->insert(product_name);
}
}

private:
std::set<std::string>* products_;
};
}

TEST_CASE("Output data products", "[graph]")
{
experimental::layer_generator gen;
gen.add_layer("spill", {"job", 1u});

experimental::framework_graph g{driver_for_test(gen)};

g.provide("provide_number", [](data_cell_index const&) -> int { return 17; })
.output_product("number_from_provider"_in("spill"));

g.transform(
"square_number",
[](int const number) -> int { return number * number; },
concurrency::unlimited)
.input_family("number_from_provider"_in("spill"))
.output_products("squared_number");

std::set<std::string> products_from_nodes;
g.make<product_recorder>(products_from_nodes)
.output("record_numbers", &product_recorder::record, concurrency::serial);

g.execute();

CHECK(g.execution_count("provide_number") == 1u);
CHECK(g.execution_count("square_number") == 1u);
// The "record_numbers" output node should be executed twice: once to receive the data
// store from the "provide_number" provider, and once to receive the data store from the
// "square_number" transform.
CHECK(g.execution_count("record_numbers") == 2u);
CHECK(products_from_nodes == std::set<std::string>{"number_from_provider", "squared_number"});
}
Loading