Skip to content

Commit

Permalink
Merge pull request #4086 from pleroy/Speculative
Browse files Browse the repository at this point in the history
Process intervals with predictable bounds in the Stehlé-Zimmermann search
  • Loading branch information
pleroy authored Sep 4, 2024
2 parents e4f16c8 + a2cb676 commit 22cd37e
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 93 deletions.
8 changes: 5 additions & 3 deletions base/bits.hpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
#pragma once

#include <cstdint>

namespace principia {
namespace base {
namespace _bits {
namespace internal {

// Floor log2 of n, or 0 for n = 0. 8 ↦ 3, 7 ↦ 2.
constexpr int FloorLog2(int n);
constexpr std::int64_t FloorLog2(std::int64_t n);

// Greatest power of 2 less than or equal to n. 8 ↦ 8, 7 ↦ 4.
constexpr int PowerOf2Le(int n);
constexpr std::int64_t PowerOf2Le(std::int64_t n);

// Computes bitreversed(bitreversed(n) + 1) assuming that n is represented on
// the given number of bits. For 4 bits:
// 0 ↦ 8 ↦ 4 ↦ C ↦ 2 ↦ A ↦ 6 ↦ E ↦ 1 ↦ 9 ↦ 5 ↦ D ↦ 3 ↦ B ↦ 7 ↦ F
constexpr int BitReversedIncrement(int n, int bits);
constexpr std::int64_t BitReversedIncrement(std::int64_t n, std::int64_t bits);

} // namespace internal

Expand Down
11 changes: 5 additions & 6 deletions base/bits_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

#include "base/bits.hpp"

#include <cstdint>

#include "base/macros.hpp" // 🧙 For CONSTEXPR_DCHECK.
#include "glog/logging.h"

Expand All @@ -12,20 +10,21 @@ namespace base {
namespace _bits {
namespace internal {

constexpr int FloorLog2(int const n) {
constexpr std::int64_t FloorLog2(std::int64_t const n) {
return n == 0 ? 0 : n == 1 ? 0 : FloorLog2(n >> 1) + 1;
}

constexpr int PowerOf2Le(int const n) {
constexpr std::int64_t PowerOf2Le(std::int64_t const n) {
return n == 0 ? 0 : n == 1 ? 1 : PowerOf2Le(n >> 1) << 1;
}

constexpr int BitReversedIncrement(int const n, int const bits) {
constexpr std::int64_t BitReversedIncrement(std::int64_t const n,
std::int64_t const bits) {
if (bits == 0) {
CONSTEXPR_DCHECK(n == 0);
return 0;
}
CONSTEXPR_DCHECK(n >= 0 && n < 1 << bits);
CONSTEXPR_DCHECK(n >= 0 && n < 1LL << bits);
CONSTEXPR_DCHECK(bits > 0 && bits < 32);
// [War03], chapter 7.1 page 105.
std::uint32_t mask = 0x8000'0000;
Expand Down
190 changes: 108 additions & 82 deletions functions/accurate_table_generator_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "functions/accurate_table_generator.hpp"

#include <algorithm>
#include <chrono>
#include <concepts>
#include <future>
#include <limits>
Expand Down Expand Up @@ -41,6 +42,7 @@ using namespace principia::quantities::_elementary_functions;
using namespace principia::quantities::_quantities;

constexpr std::int64_t T_max = 16;
static_assert(T_max >= 1);

template<std::int64_t rows, std::int64_t columns>
FixedMatrix<cpp_int, rows, columns> ToInt(
Expand Down Expand Up @@ -376,7 +378,7 @@ absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousFullSearch(

// [SZ05], section 3.2, proves that T³ = O(M * N). We use a fudge factor of 8
// to avoid starting with too small a value.
auto const T₀ =
std::int64_t const T₀ =
PowerOf2Le(8 * Cbrt(static_cast<double>(M) * static_cast<double>(N)));

// Scale the argument, functions, and polynomials to lie within [1/2, 1[.
Expand Down Expand Up @@ -425,96 +427,120 @@ absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousFullSearch(
build_scaled_polynomial(function_scales[1], polynomials[1])};

// We construct intervals above and below |scaled_argument| and search for
// solutions on each side alternatively.
Interval<cpp_rational> high_interval{
.min = scaled_argument,
.max = scaled_argument + cpp_rational(2 * T₀, N)};
Interval<cpp_rational> low_interval{
.min = scaled_argument - cpp_rational(2 * T₀, N),
.max = scaled_argument};
for (;;) {
{
auto T = T₀;
do {
VLOG(2) << "T = " << T << ", high_interval = " << high_interval;
auto const status_or_solution =
StehléZimmermannSimultaneousSearch<zeroes>(scaled_functions,
scaled_polynomials,
scaled_remainders,
high_interval.midpoint(),
N,
T);
absl::Status const& status = status_or_solution.status();
if (status.ok()) {
return status_or_solution.value() / argument_scale;
} else {
VLOG(2) << "Status = " << status;
if (absl::IsOutOfRange(status)) {
// Halve the interval. Make sure that the new interval is
// contiguous to the segment already explored.
high_interval.max = high_interval.midpoint();
T /= 2;
} else if (absl::IsNotFound(status)) {
// No solutions here, go to the next interval.
break;
// solutions on each side alternately. The intervals all have the same
// measure, 2 * T₀, and are progressively farther from the
// |starting_argument|.
for (std::int64_t index = 0;; ++index) {
auto const start = std::chrono::system_clock::now();

Interval<cpp_rational> const initial_high_interval{
.min = scaled_argument + cpp_rational(2 * index * T₀, N),
.max = scaled_argument + cpp_rational(2 * (index + 1) * T₀, N)};
Interval<cpp_rational> const initial_low_interval{
.min = scaled_argument - cpp_rational(2 * (index + 1) * T₀, N),
.max = scaled_argument - cpp_rational(2 * index * T₀, N)};

Interval<cpp_rational> high_interval = initial_high_interval;
Interval<cpp_rational> low_interval = initial_low_interval;

// The radii of the intervals remaining to cover above and below the
// `scaled_argument`.
std::int64_t high_T_to_cover = T₀;
std::int64_t low_T_to_cover = T₀;

// When exiting this loop, we have completely processed
// |initial_high_interval| and |initial_low_interval|.
for (;;) {
bool const high_interval_empty = high_interval.empty();
bool const low_interval_empty = low_interval.empty();
if (high_interval_empty && low_interval_empty) {
break;
}

if (!high_interval_empty) {
std::int64_t T = high_T_to_cover;
// This loop exits (breaks or returns) when |T <= T_max| because
// exhaustive search always gives an answer.
for (;;) {
VLOG(2) << "T = " << T << ", high_interval = " << high_interval;
auto const status_or_solution =
StehléZimmermannSimultaneousSearch<zeroes>(
scaled_functions,
scaled_polynomials,
scaled_remainders,
high_interval.midpoint(),
N,
T);
absl::Status const& status = status_or_solution.status();
if (status.ok()) {
return status_or_solution.value() / argument_scale;
} else {
return status;
VLOG(2) << "Status = " << status;
if (absl::IsOutOfRange(status)) {
// Halve the interval. Make sure that the new interval is
// contiguous to the segment already explored.
T /= 2;
high_interval.max = high_interval.min + cpp_rational(2 * T, N);
} else if (absl::IsNotFound(status)) {
// No solutions here, go to the next interval.
high_T_to_cover -= T;
break;
} else {
return status;
}
}
}
} while (T > 0);

// The Stehlé-Zimmermann algorithm doesn't work for T = 0 because the
// lattice becomes singular.
if (T == 0 && AllFunctionValuesHaveDesiredZeroes<zeroes>(
scaled_functions, high_interval.max)) {
return high_interval.max;
}
}
{
auto T = T₀;
do {
VLOG(2) << "T = " << T << ", low_interval = " << low_interval;
auto const status_or_solution =
StehléZimmermannSimultaneousSearch<zeroes>(scaled_functions,
scaled_polynomials,
scaled_remainders,
low_interval.midpoint(),
N,
T);
absl::Status const& status = status_or_solution.status();
if (status.ok()) {
return status_or_solution.value() / argument_scale;
} else {
VLOG(2) << "Status = " << status;
if (absl::IsOutOfRange(status)) {
// Halve the interval. Make sure that the new interval is
// contiguous to the segment already explored.
low_interval.min = low_interval.midpoint();
T /= 2;
} else if (absl::IsNotFound(status)) {
// No solutions here, go to the next interval.
break;
if (!low_interval_empty) {
std::int64_t T = low_T_to_cover;
// This loop exits (breaks or returns) when |T <= T_max| because
// exhaustive search always gives an answer.
for (;;) {
VLOG(2) << "T = " << T << ", low_interval = " << low_interval;
auto const status_or_solution =
StehléZimmermannSimultaneousSearch<zeroes>(
scaled_functions,
scaled_polynomials,
scaled_remainders,
low_interval.midpoint(),
N,
T);
absl::Status const& status = status_or_solution.status();
if (status.ok()) {
return status_or_solution.value() / argument_scale;
} else {
return status;
VLOG(2) << "Status = " << status;
if (absl::IsOutOfRange(status)) {
// Halve the interval. Make sure that the new interval is
// contiguous to the segment already explored.
T /= 2;
low_interval.min = low_interval.max - cpp_rational(2 * T, N);
} else if (absl::IsNotFound(status)) {
// No solutions here, go to the next interval.
low_T_to_cover -= T;
break;
} else {
return status;
}
}
}
} while (T > 0);

// The Stehlé-Zimmermann algorithm doesn't work for T = 0 because the
// lattice becomes singular.
if (T == 0 && AllFunctionValuesHaveDesiredZeroes<zeroes>(
scaled_functions, low_interval.min)) {
return low_interval.min;
}
VLOG_EVERY_N(1, 10) << "high = "
<< DebugString(
static_cast<double>(high_interval.max));
VLOG_EVERY_N(1, 10) << "low = "
<< DebugString(static_cast<double>(low_interval.min));
high_interval = {.min = high_interval.max,
.max = initial_high_interval.max};
low_interval = {.min = initial_low_interval.min,
.max = low_interval.min};
}
VLOG_EVERY_N(1, 10) << "high = "
<< DebugString(static_cast<double>(high_interval.max));
VLOG_EVERY_N(1, 10) << "low = "
<< DebugString(static_cast<double>(low_interval.min));
high_interval = {.min = high_interval.max,
.max = high_interval.max + cpp_rational(2 * T₀, N)};
low_interval = {.min = low_interval.min - cpp_rational(2 * T₀, N),
.max = low_interval.min};

auto const end = std::chrono::system_clock::now();
VLOG(1) << "Search with index " << index << " around " << starting_argument
<< " took "
<< std::chrono::duration_cast<std::chrono::microseconds>(
end - start);
}
}

Expand Down
9 changes: 7 additions & 2 deletions functions/accurate_table_generator_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ TEST_F(AccurateTableGeneratorTest, StehléZimmermannSinCos15) {
/*T=*/1ll << 21);
EXPECT_THAT(u,
IsOkAndHolds(cpp_rational(4785074575333183, 9007199254740992)));
EXPECT_EQ(*u, cpp_rational(static_cast<double>(*u)));
EXPECT_THAT(static_cast<double>(*u),
RelativeErrorFrom(u₀, Lt(1.3e-10)));
{
Expand Down Expand Up @@ -225,7 +226,8 @@ TEST_F(AccurateTableGeneratorTest, StehléZimmermannFullSinCos5NoScaling) {
{remainder_sin_taylor2, remainder_cos_taylor2},
u₀);
EXPECT_THAT(u,
IsOkAndHolds(cpp_rational(4785074604080979, 9007199254740992)));
IsOkAndHolds(cpp_rational(1196268651020245, 2251799813685248)));
EXPECT_EQ(*u, cpp_rational(static_cast<double>(*u)));
EXPECT_THAT(static_cast<double>(*u),
RelativeErrorFrom(u₀, Lt(3.7e-14)));
{
Expand All @@ -242,7 +244,7 @@ TEST_F(AccurateTableGeneratorTest, StehléZimmermannFullSinCos5NoScaling) {
/*base=*/2);
std::string_view mantissa = mathematica;
CHECK(absl::ConsumePrefix(&mantissa, "Times[2^^"));
EXPECT_THAT(mantissa.substr(53, 5), Eq("00000"));
EXPECT_THAT(mantissa.substr(53, 5), Eq("11111"));
}
}

Expand Down Expand Up @@ -285,6 +287,7 @@ TEST_F(AccurateTableGeneratorTest, StehléZimmermannFullSinCos15NoScaling) {
u₀);
EXPECT_THAT(u,
IsOkAndHolds(cpp_rational(4785074575333183, 9007199254740992)));
EXPECT_EQ(*u, cpp_rational(static_cast<double>(*u)));
EXPECT_THAT(static_cast<double>(*u),
RelativeErrorFrom(u₀, Lt(6.1e-9)));
{
Expand Down Expand Up @@ -336,6 +339,7 @@ TEST_F(AccurateTableGeneratorTest, StehléZimmermannFullSinCos15WithScaling) {
x₀);
EXPECT_THAT(x,
IsOkAndHolds(cpp_rational(4785074575333183, 36028797018963968)));
EXPECT_EQ(*x, cpp_rational(static_cast<double>(*x)));
EXPECT_THAT(static_cast<double>(*x),
RelativeErrorFrom(x₀, Lt(6.1e-9)));
{
Expand Down Expand Up @@ -398,6 +402,7 @@ TEST_F(AccurateTableGeneratorTest, StehléZimmermannMultisearchSinCos15) {
for (std::int64_t i = 0; i < xs.size(); ++i) {
CHECK_OK(xs[i].status());
auto const& x = *xs[i];
EXPECT_EQ(x, cpp_rational(static_cast<double>(x)));
EXPECT_THAT(static_cast<double>(x),
RelativeErrorFrom((i + index_begin) / 128.0, Lt(1.3e-7)));
{
Expand Down
4 changes: 4 additions & 0 deletions geometry/interval.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ struct Interval {

// The Lebesgue measure of this interval.
Difference<T> measure() const;

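// Returns true iff |measure| would return zero, but more efficiently. Note
// that this is the case whenever |max <= min|, so the single-point interval
// (|min == max|) is reported as empty too.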
bool empty() const;

// The midpoint of this interval; NaN if the interval is empty (min > max).
T midpoint() const;

Expand Down
5 changes: 5 additions & 0 deletions geometry/interval_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ Difference<T> Interval<T>::measure() const {
return max >= min ? max - min : Difference<T>{};
}

// Reports whether this interval has no extent, i.e., whether |max <= min|.
// A single-point interval (|min == max|) therefore counts as empty, as does an
// inverted one (|max < min|). Only one comparison is performed; no difference
// between the bounds is computed.
template<typename T>
bool Interval<T>::empty() const {
return max <= min;
}

template<typename T>
T Interval<T>::midpoint() const {
if constexpr (is_number<T>::value) {
Expand Down

0 comments on commit 22cd37e

Please sign in to comment.