Skip to content

Commit

Permalink
Merge pull request #435 from bluescarni/pr/sgp4_improv
Browse files Browse the repository at this point in the history
Additions to the SGP4 API
  • Loading branch information
bluescarni authored Jul 19, 2024
2 parents 768c643 + 86c25bb commit dc268b9
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 126 deletions.
5 changes: 2 additions & 3 deletions include/heyoka/model/sgp4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ class HEYOKA_DLL_PUBLIC_INLINE_CLASS sgp4_propagator
template <typename Input, typename... KwArgs>
static auto parse_ctor_args(const Input &in, const KwArgs &...kw_args)
{
if (in.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument("Cannot initialise an sgp4_propagator with a null list of satellites");
}
if (in.extent(1) == 0u) [[unlikely]] {
throw std::invalid_argument("Cannot initialise an sgp4_propagator with an empty list of satellites");
}
Expand Down Expand Up @@ -169,6 +166,8 @@ class HEYOKA_DLL_PUBLIC_INLINE_CLASS sgp4_propagator
[[nodiscard]] std::uint32_t get_nouts() const noexcept;
[[nodiscard]] mdspan<const T, extents<std::size_t, 9, std::dynamic_extent>> get_sat_data() const;

void replace_sat_data(mdspan<const T, extents<std::size_t, 9, std::dynamic_extent>>);

[[nodiscard]] std::uint32_t get_diff_order() const noexcept;
[[nodiscard]] const std::vector<expression> &get_diff_args() const;
[[nodiscard]] std::pair<std::uint32_t, std::uint32_t> get_dslice(std::uint32_t) const;
Expand Down
34 changes: 0 additions & 34 deletions src/cfunc_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,21 +452,12 @@ void cfunc<T>::single_eval(out_1d outputs, in_1d inputs, std::optional<in_1d> pa
m_impl->m_nouts, outputs.size()));
}

if (outputs.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument("The outputs array passed to a cfunc cannot be null");
}

if (inputs.size() != m_impl->m_nvars) [[unlikely]] {
throw std::invalid_argument(fmt::format("Invalid inputs array passed to a cfunc: the number of function "
"inputs is {}, but the inputs array has a size of {}",
m_impl->m_nvars, inputs.size()));
}

if (inputs.data_handle() == nullptr && !inputs.empty()) [[unlikely]] {
throw std::invalid_argument(
"The inputs array passed to a cfunc can be null only if the number of input arguments is zero");
}

if (m_impl->m_nparams != 0u && !pars) [[unlikely]] {
throw std::invalid_argument(
"An array of parameter values must be passed in order to evaluate a function with parameters");
Expand All @@ -479,11 +470,6 @@ void cfunc<T>::single_eval(out_1d outputs, in_1d inputs, std::optional<in_1d> pa
"but the number of parameters in the function is {}",
pars->size(), m_impl->m_nparams));
}

if (pars->data_handle() == nullptr && !pars->empty()) [[unlikely]] {
throw std::invalid_argument(
"The array of parameter values passed to a cfunc can be null only if the number of parameters is zero");
}
}

if (m_impl->m_is_time_dependent && !time) [[unlikely]] {
Expand Down Expand Up @@ -719,11 +705,6 @@ void cfunc<T>::multi_eval(out_2d outputs, in_2d inputs, std::optional<in_2d> par
m_impl->m_nouts, outputs.extent(0)));
}

if (outputs.data_handle() == nullptr && !outputs.empty()) [[unlikely]] {
throw std::invalid_argument(
"The outputs array passed to a cfunc can be null only if the number of evaluations is zero");
}

// Fetch the number of columns from outputs.
const auto ncols = outputs.extent(1);

Expand All @@ -740,11 +721,6 @@ void cfunc<T>::multi_eval(out_2d outputs, in_2d inputs, std::optional<in_2d> par
ncols, inputs.extent(1)));
}

if (inputs.data_handle() == nullptr && !inputs.empty()) [[unlikely]] {
throw std::invalid_argument("The inputs array passed to a cfunc can be null only if the number of input "
"arguments or the number of evaluations is zero");
}

if (m_impl->m_nparams != 0u && !pars) [[unlikely]] {
throw std::invalid_argument(
"An array of parameter values must be passed in order to evaluate a function with parameters");
Expand All @@ -765,11 +741,6 @@ void cfunc<T>::multi_eval(out_2d outputs, in_2d inputs, std::optional<in_2d> par
"outputs array is {}",
pars->extent(1), ncols));
}

if (pars->data_handle() == nullptr && !pars->empty()) [[unlikely]] {
throw std::invalid_argument("The array of parameter values passed to a cfunc can be null only if the "
"number of parameters or the number of evaluations is zero");
}
}

if (m_impl->m_is_time_dependent && !times) [[unlikely]] {
Expand All @@ -785,11 +756,6 @@ void cfunc<T>::multi_eval(out_2d outputs, in_2d inputs, std::optional<in_2d> par
"outputs array is {}",
times->size(), ncols));
}

if (times->data_handle() == nullptr && !times->empty()) [[unlikely]] {
throw std::invalid_argument("The array of time values passed to a cfunc can be null only if the "
"number of evaluations is zero");
}
}

#if defined(HEYOKA_HAVE_REAL)
Expand Down
93 changes: 66 additions & 27 deletions src/model/sgp4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ struct sgp4_propagator<T>::impl {
std::vector<T> m_sat_buffer;
std::vector<T> m_init_buffer;
cfunc<T> m_cf_tprop;
cfunc<T> m_cf_init;
std::optional<dtens> m_dtens;
// NOTE: this is a buffer used to convert dates to tsinces in the call operator
// overloads taking dates in input.
Expand All @@ -522,6 +523,7 @@ struct sgp4_propagator<T>::impl {
ar & m_sat_buffer;
ar & m_init_buffer;
ar & m_cf_tprop;
ar & m_cf_init;
ar & m_dtens;
// NOTE: no need to save the content of m_tms_vec.
}
Expand Down Expand Up @@ -625,7 +627,7 @@ sgp4_propagator<T>::sgp4_propagator(ptag, std::tuple<std::vector<T>, cfunc<T>, c
std::vector<T> init_buffer;
init_buffer.resize(boost::safe_numerics::safe<decltype(init_buffer.size())>(n_sats) * cf_init.get_nouts());

// Prepare the in/out spans for invocation of cf_init.
// Prepare the in/out spans for the invocation of cf_init.
// NOTE: for initialisation we only need to read the elements and the bstars from sat_buffer,
// the epochs do not matter. Hence, 7 rows instead of 9.
const typename cfunc<T>::in_2d init_input(sat_buffer.data(), 7, boost::numeric_cast<std::size_t>(n_sats));
Expand All @@ -637,8 +639,8 @@ sgp4_propagator<T>::sgp4_propagator(ptag, std::tuple<std::vector<T>, cfunc<T>, c
cf_init(init_output, init_input);

// Build and assign the implementation.
m_impl = std::make_unique<impl>(
impl{std::move(sat_buffer), std::move(init_buffer), std::move(cf_tprop), std::move(dt), {}});
m_impl = std::make_unique<impl>(impl{
std::move(sat_buffer), std::move(init_buffer), std::move(cf_tprop), std::move(cf_init), std::move(dt), {}});
}

template <typename T>
Expand Down Expand Up @@ -692,6 +694,67 @@ mdspan<const T, extents<std::size_t, 9, std::dynamic_extent>> sgp4_propagator<T>
m_impl->m_sat_buffer.data(), boost::numeric_cast<std::size_t>(m_impl->m_sat_buffer.size() / 9u)};
}

template <typename T>
requires std::same_as<T, double> || std::same_as<T, float>
void sgp4_propagator<T>::replace_sat_data(mdspan<const T, extents<std::size_t, 9, std::dynamic_extent>> new_data)
{
// Cache nsats.
const auto nsats = get_nsats();

if (new_data.extent(1) != nsats) [[unlikely]] {
throw std::invalid_argument(fmt::format("Invalid array provided to replace_sat_data(): the number of "
"columns ({}) does not match the number of satellites ({})",
new_data.extent(1), nsats));
}

// Fetch references to sat_buffer and init_buffer.
auto &sat_buffer = m_impl->m_sat_buffer;
auto &init_buffer = m_impl->m_init_buffer;

// Make copies of the existing data for exception safety.
// NOTE: the concern here is mostly about sat_buffer, since the user may be
// providing invalid data. However, in principle, cf_init could also throw
// when invoked, thus we save also init_buffer.
const auto old_sat_buffer = sat_buffer;
const auto old_init_buffer = init_buffer;

try {
// Write the new data into sat_buffer.
const mdspan<T, extents<std::size_t, 9, std::dynamic_extent>> buffer_span(sat_buffer.data(),
new_data.extent(1));
for (std::size_t i = 0; i < buffer_span.extent(0); ++i) {
for (std::size_t j = 0; j < buffer_span.extent(1); ++j) {
buffer_span(i, j) = new_data(i, j);
}
}

// Check the new data.
detail::sgp4_check_input_satbuf(sat_buffer);

// Fetch a reference to cf_init.
auto &cf_init = m_impl->m_cf_init;

// Prepare the in/out spans for the invocation of cf_init.
// NOTE: for initialisation we only need to read the elements and the bstars from sat_buffer,
// the epochs do not matter. Hence, 7 rows instead of 9.
// NOTE: static casts are ok, we already inited once during construction.
const typename cfunc<T>::in_2d init_input(sat_buffer.data(), 7, static_cast<std::size_t>(nsats));
const typename cfunc<T>::out_2d init_output(init_buffer.data(), static_cast<std::size_t>(cf_init.get_nouts()),
static_cast<std::size_t>(nsats));

// Evaluate the intermediate quantities and their derivatives.
cf_init(init_output, init_input);
} catch (...) {
// Restore the old data before rethrowing.
// NOTE: copy, don't move, as we need to make sure to never
// destroy/reallocate the existing buffers.
std::ranges::copy(old_sat_buffer, sat_buffer.begin());
std::ranges::copy(old_init_buffer, init_buffer.begin());

throw;
}
}

template <typename T>
requires std::same_as<T, double> || std::same_as<T, float>
void sgp4_propagator<T>::check_with_diff(const char *fname) const
Expand Down Expand Up @@ -833,9 +896,6 @@ template <typename T>
void sgp4_propagator<T>::operator()(out_2d out, in_1d<date> dates)
{
// Check the dates array.
if (dates.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument("A null array of dates was passed to the call operator of an sgp4_propagator");
}
const auto n_sats = get_nsats();
if (dates.extent(0) != n_sats) [[unlikely]] {
throw std::invalid_argument(
Expand Down Expand Up @@ -881,23 +941,6 @@ template <typename T>
requires std::same_as<T, double> || std::same_as<T, float>
void sgp4_propagator<T>::operator()(out_3d out, in_2d<T> tms)
{
// NOTE: need to check for nullptr input spans. In the non-batch overload
// we do not need the explicit check because we don't do anything with 'out'
// and 'tms' apart from forwarding them to the cfunc, which does the nullptr check.
// Here however we need to take subspans of 'out' and 'tms' and thus we need to
// pre-check for nullptr in order to avoid undefined behaviour - see the docs for
// the def ctor of mdspan:
//
// https://en.cppreference.com/w/cpp/container/mdspan/mdspan
if (out.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument(
"A null output array was passed to the batch-mode call operator of an sgp4_propagator");
}
if (tms.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument(
"A null times array was passed to the batch-mode call operator of an sgp4_propagator");
}

// Check the dimensionalities of out and tms.
const auto n_evals = out.extent(0);
if (n_evals != tms.extent(0)) [[unlikely]] {
Expand Down Expand Up @@ -972,10 +1015,6 @@ template <typename T>
void sgp4_propagator<T>::operator()(out_3d out, in_2d<date> dates)
{
// Check the dates array.
if (dates.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument(
"A null array of dates was passed to the batch-mode call operator of an sgp4_propagator");
}
const auto n_sats = get_nsats();
if (dates.extent(1) != n_sats) [[unlikely]] {
throw std::invalid_argument(fmt::format(
Expand Down
22 changes: 0 additions & 22 deletions test/cfunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,19 +350,6 @@ TEST_CASE("single call operator")
REQUIRE(output2[1] == -8);
std::ranges::fill(output2, fp_t(0));

// Null output span.
REQUIRE_THROWS_MATCHES(
cf0(typename cfunc<fp_t>::out_1d{nullptr, 2u}, std::array<fp_t, 0>{}, kw::pars = par1, kw::time = fp_t(10)),
std::invalid_argument, Message("The outputs array passed to a cfunc cannot be null"));

// Null input span with inputs.
cf0 = cfunc<fp_t>({x + y - heyoka::time, x - y + par[0]}, {x, y}, kw::opt_level = opt_level,
kw::high_accuracy = high_accuracy, kw::compact_mode = compact_mode);
REQUIRE_THROWS_MATCHES(
cf0(output2, typename cfunc<fp_t>::in_1d{nullptr, 2}, kw::pars = par1, kw::time = fp_t(10)),
std::invalid_argument,
Message("The inputs array passed to a cfunc can be null only if the number of input arguments is zero"));

// Null par span with no pars.
cf0 = cfunc<fp_t>({x + y - heyoka::time, x - y}, {x, y}, kw::opt_level = opt_level,
kw::high_accuracy = high_accuracy, kw::compact_mode = compact_mode);
Expand All @@ -373,15 +360,6 @@ TEST_CASE("single call operator")
REQUIRE(output2[0] == -7);
REQUIRE(output2[1] == -1);
std::ranges::fill(output2, fp_t(0));

// Null par span with pars.
cf0 = cfunc<fp_t>({x + y - heyoka::time, x - y + par[0]}, {x, y}, kw::opt_level = opt_level,
kw::high_accuracy = high_accuracy, kw::compact_mode = compact_mode);
REQUIRE_THROWS_MATCHES(cf0(output2, std::array<fp_t, 2>{1, 2},
kw::pars = typename cfunc<fp_t>::in_1d{nullptr, 1}, kw::time = fp_t(10)),
std::invalid_argument,
Message("The array of parameter values passed to a cfunc can be null only if the number "
"of parameters is zero"));
};

for (auto cm : {false, true}) {
Expand Down
24 changes: 0 additions & 24 deletions test/cfunc_multieval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,6 @@ TEST_CASE("multieval st")
// Check no error on zero nevals with null outputs span.
REQUIRE_NOTHROW(cf0(out_2d{nullptr, 2, 0}, in_2d{ibuf.data(), 2, 0}));

// Check error with null outputs span and nonzero evals.
REQUIRE_THROWS_MATCHES(
cf0(out_2d{nullptr, 2, 10}, in_2d{ibuf.data(), 0, 0}), std::invalid_argument,
Message("The outputs array passed to a cfunc can be null only if the number of evaluations is zero"));

obuf.resize(20u);

REQUIRE_THROWS_MATCHES(cf0(out_2d{obuf.data(), 2, 10}, in_2d{ibuf.data(), 0, 0}), std::invalid_argument,
Expand All @@ -113,11 +108,6 @@ TEST_CASE("multieval st")
Message("Invalid inputs array passed to a cfunc: the expected number of columns deduced from the "
"outputs array is 10, but the number of columns in the inputs array is 5"));

// Null input span.
REQUIRE_THROWS_MATCHES(cf0(out_2d{obuf.data(), 2, 10}, in_2d{nullptr, 2, 10}), std::invalid_argument,
Message("The inputs array passed to a cfunc can be null only if the number of input "
"arguments or the number of evaluations is zero"));

cf0 = cfunc<fp_t>{{x + y + par[0], x - y + heyoka::time},
{x, y},
kw::opt_level = opt_level,
Expand Down Expand Up @@ -167,20 +157,6 @@ TEST_CASE("multieval st")
"but the expected size deduced from the "
"outputs array is 10"));

// Null par span.
REQUIRE_THROWS_MATCHES(cf0(out_2d{obuf.data(), 2, 10}, in_2d{ibuf.data(), 2, 10},
kw::pars = in_2d{nullptr, 1, 10}, kw::time = in_1d{tbuf.data(), 5}),
std::invalid_argument,
Message("The array of parameter values passed to a cfunc can be null only if the "
"number of parameters or the number of evaluations is zero"));

// Null time span.
REQUIRE_THROWS_MATCHES(cf0(out_2d{obuf.data(), 2, 10}, in_2d{ibuf.data(), 2, 10},
kw::pars = in_2d{pbuf.data(), 1, 10}, kw::time = in_1d{nullptr, 10}),
std::invalid_argument,
Message("The array of time values passed to a cfunc can be null only if the "
"number of evaluations is zero"));

// Functional testing.
cf0 = cfunc<fp_t>{{x + y, x - y},
{x, y},
Expand Down
Loading

0 comments on commit dc268b9

Please sign in to comment.