Skip to content

Commit

Permalink
Merge pull request #3447 from pleroy/GradientDescentCleanup
Browse files Browse the repository at this point in the history
Change the gradient descent code to not depend on Position
  • Loading branch information
pleroy authored Oct 23, 2022
2 parents b26c192 + 13029b2 commit 5f889a4
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 66 deletions.
42 changes: 24 additions & 18 deletions numerics/gradient_descent.hpp
Original file line number Diff line number Diff line change
@@ -1,33 +1,39 @@


#pragma once

#include <functional>

#include "geometry/grassmann.hpp"
#include "geometry/named_quantities.hpp"
#include "geometry/hilbert.hpp"
#include "quantities/named_quantities.hpp"
#include "quantities/quantities.hpp"

namespace principia {
namespace numerics {
namespace internal_gradient_descent {

using geometry::Position;
using geometry::Vector;
using quantities::Derivative;
using geometry::Hilbert;
using quantities::Difference;
using quantities::Length;

template<typename Value, typename Frame>
using Field = std::function<Value(Position<Frame> const&)>;

template<typename Scalar, typename Frame>
using Gradient = Vector<Derivative<Scalar, Length>, Frame>;

template<typename Scalar, typename Frame>
Position<Frame> BroydenFletcherGoldfarbShanno(
Position<Frame> const& start_position,
Field<Scalar, Frame> const& f,
Field<Gradient<Scalar, Frame>, Frame> const& grad_f,
using quantities::Product;
using quantities::Quotient;

// In this file |Argument| must be such that its difference belongs to a Hilbert
// space.

template<typename Scalar, typename Argument>
using Field = std::function<Scalar(Argument const&)>;

template<typename Scalar, typename Argument>
using Gradient =
Product<Scalar,
Quotient<Difference<Argument>,
typename Hilbert<Difference<Argument>>::Norm²Type>>;

template<typename Scalar, typename Argument>
Argument BroydenFletcherGoldfarbShanno(
Argument const& start_argument,
Field<Scalar, Argument> const& f,
Field<Gradient<Scalar, Argument>, Argument> const& grad_f,
Length const& tolerance);

} // namespace internal_gradient_descent
Expand Down
59 changes: 33 additions & 26 deletions numerics/gradient_descent_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,28 @@ namespace principia {
namespace numerics {
namespace internal_gradient_descent {

using geometry::Displacement;
using geometry::InnerProduct;
using geometry::InnerProductForm;
using geometry::Normalize;
using geometry::SymmetricBilinearForm;
using geometry::SymmetricProduct;
using geometry::Vector;
using quantities::Abs;
using quantities::Quotient;
using quantities::Square;
using quantities::si::Metre;
namespace si = quantities::si;

// The type of Hₖ, which approximates the inverse of the Hessian.
// A helper to use |Argument| with |SymmetricBilinearForm|.
template<typename A>
struct ArgumentHelper;

template<typename Scalar, typename Frame>
using InverseHessian =
SymmetricBilinearForm<Quotient<Square<Length>, Scalar>, Frame, Vector>;
struct ArgumentHelper<Vector<Scalar, Frame>> {
static SymmetricBilinearForm<double, Frame, Vector> InnerProductForm() {
return geometry::InnerProductForm<Frame, Vector>();
}
};

// The line search follows [NW06], algorithms 3.5 and 3.6, which guarantee that
// the chosen step obeys the strong Wolfe conditions.
Expand All @@ -40,18 +46,18 @@ constexpr double c₁ = 1e-4;
constexpr double c₂ = 0.9;
constexpr double α_multiplier = 2;

template<typename Scalar, typename Frame>
template<typename Scalar, typename Argument>
double Zoom(double α_lo,
double α_hi,
Scalar ϕ_α_lo,
Scalar ϕ_α_hi,
Scalar ϕʹ_α_lo,
Scalar const& ϕ_0,
Scalar const& ϕʹ_0,
Position<Frame> const& x,
Displacement<Frame> const& p,
Field<Scalar, Frame> const& f,
Field<Gradient<Scalar, Frame>, Frame> const& grad_f) {
Argument const& x,
Difference<Argument> const& p,
Field<Scalar, Argument> const& f,
Field<Gradient<Scalar, Argument>, Argument> const& grad_f) {
std::optional<Scalar> previous_ϕ_αⱼ;
for (;;) {
DCHECK_NE(α_lo, α_hi);
Expand Down Expand Up @@ -98,11 +104,11 @@ double Zoom(double α_lo,
}
}

template<typename Scalar, typename Frame>
double LineSearch(Position<Frame> const& x,
Displacement<Frame> const& p,
Field<Scalar, Frame> const& f,
Field<Gradient<Scalar, Frame>, Frame> const& grad_f) {
template<typename Scalar, typename Argument>
double LineSearch(Argument const& x,
Difference<Argument> const& p,
Field<Scalar, Argument> const& f,
Field<Gradient<Scalar, Argument>, Argument> const& grad_f) {
auto const ϕ_0 = f(x);
auto const ϕʹ_0 = InnerProduct(grad_f(x), p);
double αᵢ₋₁ = 0; // α₀.
Expand Down Expand Up @@ -140,41 +146,42 @@ double LineSearch(Position<Frame> const& x,
}

// The implementation of BFGS follows [NW06], algorithm 6.18.
template<typename Scalar, typename Frame>
Position<Frame> BroydenFletcherGoldfarbShanno(
Position<Frame> const& start_position,
Field<Scalar, Frame> const& f,
Field<Gradient<Scalar, Frame>, Frame> const& grad_f,
template<typename Scalar, typename Argument>
Argument BroydenFletcherGoldfarbShanno(
Argument const& start_argument,
Field<Scalar, Argument> const& f,
Field<Gradient<Scalar, Argument>, Argument> const& grad_f,
Length const& tolerance) {
// The first step uses vanilla steepest descent.
auto const x₀ = start_position;
auto const x₀ = start_argument;
auto const grad_f_x₀ = grad_f(x₀);

if (grad_f_x₀ == Gradient<Scalar, Frame>{}) {
if (grad_f_x₀ == Gradient<Scalar, Argument>{}) {
return x₀;
}

// We (ab)use the tolerance to determine the first step size. The assumption
// is that, if the caller provides a reasonable value then (1) we won't miss
// "interesting features" of f; (2) the finite differences won't underflow or
// have other unpleasant properties.
Displacement<Frame> const p₀ = -Normalize(grad_f_x₀) * tolerance;
Difference<Argument> const p₀ = -Normalize(grad_f_x₀) * tolerance;

double const α₀ = LineSearch(x₀, p₀, f, grad_f);
auto const x₁ = x₀+ α₀ * p₀;

// Special computation of H₀ using (6.20).
auto const grad_f_x₁ = grad_f(x₁);
Displacement<Frame> const s₀ = x₁ - x₀;
Difference<Argument> const s₀ = x₁ - x₀;
auto const y₀ = grad_f_x₁ - grad_f_x₀;
InverseHessian<Scalar, Frame> const H₀ =
InnerProduct(s₀, y₀) * InnerProductForm<Frame, Vector>() / y₀.Norm²();
auto const H₀ = InnerProduct(s₀, y₀) *
ArgumentHelper<Difference<Argument>>::InnerProductForm() /
y₀.Norm²();

auto xₖ = x₁;
auto grad_f_xₖ = grad_f_x₁;
auto Hₖ = H₀;
for (;;) {
Displacement<Frame> const pₖ = -Hₖ * grad_f_xₖ;
Difference<Argument> const pₖ = -Hₖ * grad_f_xₖ;
if (pₖ.Norm() <= tolerance) {
return xₖ;
}
Expand Down
45 changes: 24 additions & 21 deletions numerics/gradient_descent_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ TEST_F(GradientDescentTest, Quadratic) {
Position<World> const expected_minimum =
World::origin + Displacement<World>({1 * Metre, 2 * Metre, -3 * Metre});
auto const actual_minimum =
BroydenFletcherGoldfarbShanno<Exponentiation<Length, 2>, World>(
/*start_position=*/World::origin,
BroydenFletcherGoldfarbShanno<Exponentiation<Length, 2>, Position<World>>(
/*start_argument=*/World::origin,
field,
gradient,
/*tolerance=*/1 * Micro(Metre));
Expand All @@ -80,8 +80,8 @@ TEST_F(GradientDescentTest, Quartic) {
Position<World> const expected_minimum =
World::origin + Displacement<World>({1 * Metre, 2 * Metre, -3 * Metre});
auto const actual_minimum =
BroydenFletcherGoldfarbShanno<Exponentiation<Length, 4>, World>(
/*start_position=*/World::origin,
BroydenFletcherGoldfarbShanno<Exponentiation<Length, 4>, Position<World>>(
/*start_argument=*/World::origin,
field,
gradient,
/*tolerance=*/1 * Micro(Metre));
Expand Down Expand Up @@ -111,11 +111,12 @@ TEST_F(GradientDescentTest, Gaussian) {

Position<World> const expected_minimum =
World::origin + Displacement<World>({1 * Metre, 2 * Metre, -3 * Metre});
auto const actual_minimum = BroydenFletcherGoldfarbShanno<double, World>(
/*start_position=*/World::origin,
field,
gradient,
/*tolerance=*/1 * Micro(Metre));
auto const actual_minimum =
BroydenFletcherGoldfarbShanno<double, Position<World>>(
/*start_argument=*/World::origin,
field,
gradient,
/*tolerance=*/1 * Micro(Metre));
EXPECT_THAT(
actual_minimum,
AbsoluteErrorFrom(expected_minimum, IsNear(0.82_(1) * Micro(Metre))));
Expand Down Expand Up @@ -143,23 +144,25 @@ TEST_F(GradientDescentTest, Rosenbrock) {
Position<World> const expected_minimum =
World::origin + Displacement<World>({scale, scale, 0 * Metre});
{
auto const actual_minimum = BroydenFletcherGoldfarbShanno<double, World>(
/*start_position=*/World::origin +
Displacement<World>({1.2 * scale, 1.2 * scale, 0 * Metre}),
field,
gradient,
/*tolerance=*/1 * Micro(Metre));
auto const actual_minimum =
BroydenFletcherGoldfarbShanno<double, Position<World>>(
/*start_argument=*/World::origin +
Displacement<World>({1.2 * scale, 1.2 * scale, 0 * Metre}),
field,
gradient,
/*tolerance=*/1 * Micro(Metre));
EXPECT_THAT(
actual_minimum,
AbsoluteErrorFrom(expected_minimum, IsNear(0.96_(1) * Micro(Metre))));
}
{
auto const actual_minimum = BroydenFletcherGoldfarbShanno<double, World>(
/*start_position=*/World::origin +
Displacement<World>({-1.2 * scale, scale, 0 * Metre}),
field,
gradient,
/*tolerance=*/1 * Micro(Metre));
auto const actual_minimum =
BroydenFletcherGoldfarbShanno<double, Position<World>>(
/*start_argument=*/World::origin +
Displacement<World>({-1.2 * scale, scale, 0 * Metre}),
field,
gradient,
/*tolerance=*/1 * Micro(Metre));
EXPECT_THAT(
actual_minimum,
AbsoluteErrorFrom(expected_minimum, IsNear(0.047_(1) * Micro(Metre))));
Expand Down
2 changes: 1 addition & 1 deletion physics/equipotential_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ auto Equipotential<InertialFrame, Frame>::ComputeLine(
// NOTE(phl): Unclear if |length_integration_tolerance| is the right thing to
// use below.
auto const equipotential_position =
BroydenFletcherGoldfarbShanno<Square<SpecificEnergy>, Frame>(
BroydenFletcherGoldfarbShanno<Square<SpecificEnergy>, Position<Frame>>(
degrees_of_freedom.position(),
f,
grad_f,
Expand Down

0 comments on commit 5f889a4

Please sign in to comment.