Skip to content

Commit e44af78

Browse files
committed
Shift async_driver gear into park if exception from workflow graph
1 parent 5a46fa3 commit e44af78

File tree

6 files changed

+80
-8
lines changed

6 files changed

+80
-8
lines changed

phlex/core/framework_graph.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ namespace phlex::experimental {
2323

2424
layer_sentry::~layer_sentry()
2525
{
26+
// To consider: We may want to skip the following logic if the framework prematurely
27+
// needs to shut down. Keeping it enabled allows in-flight folds to
28+
// complete. However, in some cases it may not be desirable to do this.
2629
auto flush_result = counters_.extract(store_->id());
2730
auto flush_store = store_->make_flush();
2831
if (not flush_result.empty()) {
@@ -65,7 +68,16 @@ namespace phlex::experimental {
6568
eoms_.push(nullptr);
6669
}
6770

68-
framework_graph::~framework_graph() = default;
71+
/// Destroys the framework graph.
///
/// Nothing beyond member destruction happens unless `shutdown_on_error_` was
/// set, which the exception handlers around graph execution do before
/// rethrowing. In that error state the layer stack is emptied explicitly and
/// the graph is waited on before the members are destroyed.
framework_graph::~framework_graph()
{
  if (not shutdown_on_error_) {
    return;
  }

  // When in an error state, we need to sanely pop the layer stack (top-most
  // layer first) and wait for any tasks to finish.
  while (not layers_.empty()) {
    layers_.pop();
  }
  graph_.wait_for_all();
}
6981

7082
std::size_t framework_graph::execution_counts(std::string const& node_name) const
7183
{
@@ -82,10 +94,14 @@ namespace phlex::experimental {
8294
finalize();
8395
run();
8496
} catch (std::exception const& e) {
97+
driver_.stop();
8598
spdlog::error(e.what());
99+
shutdown_on_error_ = true;
86100
throw;
87101
} catch (...) {
102+
driver_.stop();
88103
spdlog::error("Unknown exception during graph execution");
104+
shutdown_on_error_ = true;
89105
throw;
90106
}
91107

phlex/core/framework_graph.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ namespace phlex::experimental {
186186
std::queue<product_store_ptr> pending_stores_;
187187
flush_counters counters_;
188188
std::stack<layer_sentry> layers_;
189-
bool shutdown_{false};
189+
bool shutdown_on_error_{false};
190190
};
191191
}
192192

phlex/core/fwd.hpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define PHLEX_CORE_FWD_HPP
33

44
#include "phlex/model/fwd.hpp"
5-
#include "phlex/utilities/async_driver.hpp"
65

76
#include <memory>
87

@@ -20,10 +19,6 @@ namespace phlex::experimental {
2019
using end_of_message_ptr = std::shared_ptr<end_of_message>;
2120
}
2221

23-
namespace phlex {
24-
using framework_driver = experimental::async_driver<data_cell_index_ptr>;
25-
}
26-
2722
#endif // PHLEX_CORE_FWD_HPP
2823

2924
// Local Variables:

phlex/driver.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@
66
#include "phlex/configuration.hpp"
77
#include "phlex/core/fwd.hpp"
88
#include "phlex/model/product_store.hpp"
9+
#include "phlex/utilities/async_driver.hpp"
910

1011
#include <concepts>
1112
#include <memory>
1213

14+
namespace phlex {
15+
using framework_driver = experimental::async_driver<data_cell_index_ptr>;
16+
}
17+
1318
namespace phlex::experimental::detail {
1419

1520
// See note below.

phlex/utilities/async_driver.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,23 @@ namespace phlex::experimental {
5050
return std::exchange(current_, std::nullopt);
5151
}
5252

53+
void stop()
54+
{
55+
// API that should only be called by the framework_graph
56+
gear_ = states::park;
57+
cv_.notify_one();
58+
}
59+
5360
/// Publishes `rt` as the current value and blocks the calling thread until
/// that value has been taken (i.e. `current_` is empty again) or the driver
/// has been shifted into `park`.
///
/// @param rt Value to publish; it is moved into `current_`.
/// @throws std::runtime_error if the driver is in `park` once the wait ends.
void yield(RT rt)
{
  // The wait ends when the published value is gone or the driver is parked.
  auto const released = [this] { return gear_ == states::park or not current_.has_value(); };

  std::unique_lock guard{mutex_};
  current_ = std::make_optional(std::move(rt));
  cv_.notify_one();
  cv_.wait(guard, released);

  if (gear_ != states::park) {
    return;
  }

  // Can only be in park at this point if the framework needs to prematurely shut down
  throw std::runtime_error("Framework shutdown");
}
6071

6172
private:

test/framework_graph.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "phlex/core/framework_graph.hpp"
2+
#include "phlex/utilities/max_allowed_parallelism.hpp"
23
#include "plugins/layer_generator.hpp"
34

45
#include "catch2/catch_test_macros.hpp"
@@ -39,3 +40,47 @@ TEST_CASE("Make progress with one thread", "[graph]")
3940
CHECK(g.execution_counts("provide_number") == 1000);
4041
CHECK(g.execution_counts("observe_number") == 1000);
4142
}
43+
44+
// Verifies that an exception thrown inside the workflow graph propagates out of
// execute() and stops the driver from emitting the remaining data cells.
TEST_CASE("Stop driver when workflow throws exception", "[graph]")
{
  experimental::layer_generator generator;
  generator.add_layer("spill", {"job", 1000});

  experimental::framework_graph graph{driver_for_test(generator)};

  // Provider that fails on every invocation.
  graph
    .provide(
      "throw_exception",
      [](data_cell_index const&) -> unsigned int {
        throw std::runtime_error("Error to stop driver");
      },
      concurrency::unlimited)
    .output_product("number"_in("spill"));

  // The provider is executed only if at least one downstream node requires its
  // product, so register an observer that consumes it.
  graph
    .observe(
      "downstream_of_exception", [](unsigned int) {}, concurrency::unlimited)
    .input_family("number"_in("spill"));

  CHECK_THROWS(graph.execute());

  // A framework job may have up to N + 1 threads: N is the number configured by
  // the user, and the extra one is the separate std::jthread created by the
  // async_driver. Each "pull" from the async_driver is serialized, but once the
  // flow graph has pulled an index, that index is sent to downstream nodes for
  // further processing.
  //
  // The first node to process an index is a provider that throws immediately.
  // This puts the framework graph in an error state, in which the async_driver
  // is short-circuited from doing further processing.
  //
  // The assumption is that one of those threads triggers the exception and the
  // remaining threads must be permitted to complete.
  auto const max_emitted_cells =
    static_cast<std::size_t>(experimental::max_allowed_parallelism::active_value() + 1);
  CHECK(generator.emitted_cells("/job/spill") <= max_emitted_cells);

  // A node counts as "executed" only after it has returned successfully, so
  // neither the "throw_exception" provider nor the "downstream_of_exception"
  // observer will have executed.
  CHECK(graph.execution_counts("throw_exception") == 0ull);
  CHECK(graph.execution_counts("downstream_of_exception") == 0ull);
}

0 commit comments

Comments
 (0)