Skip to content

Commit 80fd7de

Browse files
committed
feat: initial forcings engine implementation
1 parent 52f4354 commit 80fd7de

File tree

5 files changed

+184
-0
lines changed

5 files changed

+184
-0
lines changed

Diff for: include/forcing/ForcingEngineDataProvider.hpp

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include "GenericDataProvider.hpp"
2+
#include <NGenConfig.h>
3+
4+
#if NGEN_WITH_PYTHON
5+
# include "bmi/Bmi_Py_Adapter.hpp"
6+
#else
7+
# error ForcingEngineDataProvider requires Python support.
8+
#endif
9+
10+
#include <memory>
11+
12+
namespace data_access {
13+
14+
struct ForcingEngineDataProvider
15+
: public GenericDataProvider
16+
{
17+
static_assert(
18+
ngen::exec_info::with_python,
19+
"ForcingEngineDataProvider requires Python support."
20+
);
21+
22+
explicit ForcingEngineDataProvider(const std::string& init);
23+
24+
~ForcingEngineDataProvider() override;
25+
26+
auto get_available_variable_names()
27+
-> boost::span<const std::string> override;
28+
29+
auto get_data_start_time()
30+
-> long override;
31+
32+
auto get_data_stop_time()
33+
-> long override;
34+
35+
auto record_duration()
36+
-> long override;
37+
38+
auto get_ts_index_for_time(const time_t& epoch_time)
39+
-> size_t override;
40+
41+
auto get_value(const CatchmentAggrDataSelector& selector, ReSampleMethod m)
42+
-> double override;
43+
44+
auto get_values(const CatchmentAggrDataSelector& selector, data_access::ReSampleMethod m)
45+
-> std::vector<double> override;
46+
47+
private:
48+
models::bmi::Bmi_Py_Adapter instance_;
49+
std::vector<std::string> outputs_;
50+
};
51+
52+
} // namespace data_access

Diff for: src/forcing/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ target_include_directories(forcing PUBLIC
1414
)
1515

1616
target_link_libraries(forcing PUBLIC
17+
NGen::config_header
1718
Boost::boost # Headers-only Boost
1819
Threads::Threads
1920
)
@@ -22,4 +23,8 @@ if(NGEN_WITH_NETCDF)
2223
target_link_libraries(forcing PUBLIC NetCDF)
2324
endif()
2425

26+
if(NGEN_WITH_PYTHON)
27+
target_link_libraries(forcing PUBLIC pybind11::embed NGen::core)
28+
endif()
29+
2530
#target_compile_options(forcing PUBLIC -std=c++14 -Wall)

Diff for: src/forcing/ForcingEngineDataProvider.cpp

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#include "ForcingEngineDataProvider.hpp"
2+
3+
#include <boost/core/span.hpp>
4+
#include <boost/utility/string_view.hpp>
5+
6+
namespace data_access {
7+
8+
ForcingEngineDataProvider::ForcingEngineDataProvider(const std::string& init)
9+
: instance_(
10+
"ForcingEngine",
11+
init,
12+
"NextGen_Forcings_Engine.BMIForcingsEngine",
13+
true,
14+
true,
15+
utils::getStdOut()
16+
)
17+
, outputs_(instance_.GetOutputVarNames())
18+
{};
19+
20+
ForcingEngineDataProvider::~ForcingEngineDataProvider() = default;
21+
22+
auto ForcingEngineDataProvider::get_available_variable_names()
23+
-> boost::span<const std::string>
24+
{
25+
return outputs_;
26+
}
27+
28+
auto ForcingEngineDataProvider::get_data_start_time()
29+
-> long
30+
{
31+
// FIXME: Temporary, but most likely incorrect, cast
32+
return static_cast<int64_t>(instance_.GetStartTime());
33+
}
34+
35+
auto ForcingEngineDataProvider::get_data_stop_time()
36+
-> long
37+
{
38+
// FIXME: Temporary, but most likely incorrect, cast
39+
return static_cast<int64_t>(instance_.GetEndTime());
40+
}
41+
42+
auto ForcingEngineDataProvider::record_duration()
43+
-> long
44+
{
45+
// FIXME: Temporary, but most likely incorrect, cast
46+
return static_cast<int64_t>(instance_.GetTimeStep());
47+
}
48+
49+
auto ForcingEngineDataProvider::get_ts_index_for_time(const time_t& epoch_time)
50+
-> size_t
51+
{
52+
// TODO: implementation
53+
throw std::runtime_error{"not implemented"};
54+
}
55+
56+
auto ForcingEngineDataProvider::get_value(const CatchmentAggrDataSelector& selector, ReSampleMethod m)
57+
-> double
58+
{
59+
const std::string var_name = selector.get_variable_name();
60+
const std::string divide_id = selector.get_id();
61+
const std::size_t count = instance_.GetVarNbytes("CAT-ID") / instance_.GetVarItemsize("CAT-ID");
62+
const auto divide_int_id = std::atoi(divide_id.data() + divide_id.find('-'));
63+
const auto ids = boost::span<const int>{static_cast<int*>(instance_.GetValuePtr("CAT-ID")), count};
64+
const auto* pos = std::find(ids.cbegin(), ids.cend(), divide_int_id);
65+
66+
if (pos == std::end(ids)) {
67+
throw std::runtime_error("Failed to get variable");
68+
}
69+
70+
const auto result_index = static_cast<int>(std::distance(ids.cbegin(), pos));
71+
const auto values = boost::span<const double>{
72+
static_cast<double*>(instance_.GetValuePtr(var_name)),
73+
static_cast<std::size_t>(instance_.GetVarNbytes(var_name) / instance_.GetVarItemsize(var_name))
74+
};
75+
76+
return values[result_index];
77+
}
78+
79+
auto ForcingEngineDataProvider::get_values(const CatchmentAggrDataSelector& selector, data_access::ReSampleMethod m)
80+
-> std::vector<double>
81+
{
82+
return {0.0};
83+
}
84+
85+
} // namespace data_access

Diff for: test/CMakeLists.txt

+12
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,18 @@ ngen_add_test(
417417
)
418418
#endif()
419419

420+
ngen_add_test(
421+
test_bmi_forcing
422+
OBJECTS
423+
forcing/ForcingEngineDataProvider_Test.cpp
424+
LIBRARIES
425+
NGen::core
426+
NGen::forcing
427+
NGen::ngen_bmi
428+
REQUIRES
429+
NGEN_WITH_PYTHON
430+
)
431+
420432
########################## Primary Combined Unit Test Target
421433
ngen_add_test(
422434
test_unit

Diff for: test/forcing/ForcingEngineDataProvider_Test.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include <gtest/gtest.h>
2+
3+
#include "ForcingEngineDataProvider.hpp"
4+
5+
#include <boost/range/combine.hpp>
6+
7+
TEST(ForcingEngineDataProviderTest, Initialization) {
8+
auto interp_ = utils::ngenPy::InterpreterUtil::getInstance();
9+
10+
data_access::ForcingEngineDataProvider provider{
11+
"extern/ngen-forcing/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/config.yml"
12+
};
13+
14+
const auto outputs_test = provider.get_available_variable_names();
15+
const auto outputs_expected = {
16+
"CAT-ID",
17+
"U2D_ELEMENT",
18+
"V2D_ELEMENT",
19+
"LWDOWN_ELEMENT",
20+
"SWDOWN_ELEMENT",
21+
"T2D_ELEMENT",
22+
"Q2D_ELEMENT",
23+
"PSFC_ELEMENT",
24+
"RAINRATE_ELEMENT"
25+
};
26+
27+
for (decltype(auto) output : boost::combine(outputs_test, outputs_expected)) {
28+
EXPECT_EQ(output.get<0>(), output.get<1>());
29+
}
30+
}

0 commit comments

Comments
 (0)