Commit

Merge pull request #19 from msakai/ad-more-modes

Add more AD modes

msakai authored May 23, 2024
2 parents c1f5b1e + 4ed540f commit a3be3d6
Showing 2 changed files with 180 additions and 41 deletions.
1 change: 1 addition & 0 deletions numeric-optimization-ad/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to the
## 0.2.0.0 - Unreleased

* Redesign API using newly introduced `Domain` type.
* Support more AD modes: `Dense`, `Forward`, `Kahn`

## 0.1.0.1 - 2023-06-03

Expand Down
220 changes: 179 additions & 41 deletions numeric-optimization-ad/src/Numeric/Optimization/AD.hs
@@ -1,3 +1,4 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
@@ -16,16 +17,52 @@
-- This module is a wrapper of "Numeric.Optimization" that uses
-- [ad](https://hackage.haskell.org/package/ad)'s automatic differentiation.
--
-- This module provides @Using/Foo/@ types for wrapping functions into
-- optimization problems ('IsProblem') whose gradients (and Hessians) are
-- computed by automatic differentiation, using ad's corresponding
-- @Numeric.AD.Mode./Foo/@ module.
--
-- Example:
--
-- > import Numeric.Optimization
-- > import Numeric.Optimization.AD
-- >
-- > main :: IO ()
-- > main = do
-- > result <- minimize LBFGS def (UsingReverse rosenbrock) [-3,-4]
-- > print (resultSuccess result) -- True
-- > print (resultSolution result) -- [0.999999999009131,0.9999999981094296]
-- > print (resultValue result) -- 1.8129771632403013e-18
-- >
-- > -- https://en.wikipedia.org/wiki/Rosenbrock_function
-- > rosenbrock :: Floating a => [a] -> a
-- > -- rosenbrock :: Reifies s Tape => [Reverse s Double] -> Reverse s Double
-- > rosenbrock [x,y] = sq (1 - x) + 100 * sq (y - sq x)
-- >
-- > sq :: Floating a => a -> a
-- > sq x = x ** 2
--
-----------------------------------------------------------------------------
module Numeric.Optimization.AD
(
-- * Problem specification
#if MIN_VERSION_ad(4,5,0)
UsingDense (..)
,
#endif
UsingForward (..)
, UsingKahn (..)
, UsingReverse (..)
, UsingSparse (..)

-- * Utilities and Re-exports
, AD
, auto
#if MIN_VERSION_ad(4,5,0)
, Dense
#endif
, Forward
, Kahn
, Reverse
, Reifies
, Tape
@@ -43,6 +80,14 @@ import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import Numeric.AD (AD, auto)
import Numeric.AD.Internal.Reverse (Tape)
#if MIN_VERSION_ad(4,5,0)
import Numeric.AD.Mode.Dense (Dense)
import qualified Numeric.AD.Mode.Dense as Dense
#endif
import Numeric.AD.Mode.Forward (Forward)
import qualified Numeric.AD.Mode.Forward as Forward
import Numeric.AD.Mode.Kahn (Kahn)
import qualified Numeric.AD.Mode.Kahn as Kahn
import Numeric.AD.Mode.Reverse (Reverse)
import qualified Numeric.AD.Mode.Reverse as Reverse
import Numeric.AD.Mode.Sparse (Sparse)
@@ -52,28 +97,141 @@ import Numeric.Optimization

-- ------------------------------------------------------------------------

#if MIN_VERSION_ad(4,5,0)

-- | Type for defining a function and its gradient using automatic differentiation
-- provided by "Numeric.AD.Mode.Dense".
--
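-- A usage sketch, assuming the 'rosenbrock' function from the module-header
-- example; @UsingDense@ is only available when built against ad >= 4.5:
--
-- > result <- minimize LBFGS def (UsingDense rosenbrock) [-3,-4]
--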
-- @since 0.2.0.0
data UsingDense f
= UsingDense (forall s. f (AD s (Dense f Double)) -> AD s (Dense f Double))

instance Traversable f => IsProblem (UsingDense f) where
type Domain (UsingDense f) = f Double

dim _ = length

toVector _ = VG.fromList . toList

writeToMVector _ = writeToMVector'

updateFromVector _ = updateFromVector'

func (UsingDense f) x = fst $ Dense.grad' f x

bounds (UsingDense _f) = Nothing

constraints (UsingDense _f) = []

instance Traversable f => HasGrad (UsingDense f) where
grad (UsingDense f) x = Dense.grad f x

grad' (UsingDense f) x = Dense.grad' f x

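  -- Write the gradient into the caller-supplied mutable vector and return
  -- the function value (the other modes below follow the same pattern).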
grad'M (UsingDense f) x gvec =
case Dense.grad' f x of
(y, g) -> do
writeToMVector' g gvec
return y

instance Traversable f => Optionally (HasGrad (UsingDense f)) where
optionalDict = hasOptionalDict

instance Optionally (HasHessian (UsingDense f)) where
optionalDict = Nothing

#endif

-- ------------------------------------------------------------------------

-- | Type for defining a function and its gradient using automatic differentiation
-- provided by "Numeric.AD.Mode.Forward".
--
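-- A usage sketch, assuming the 'rosenbrock' function from the module-header
-- example. Forward mode needs one evaluation pass per input dimension to
-- build a gradient, so it is mainly attractive for low-dimensional problems:
--
-- > result <- minimize LBFGS def (UsingForward rosenbrock) [-3,-4]
--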
-- @since 0.2.0.0
data UsingForward f
= UsingForward (forall s. f (AD s (Forward Double)) -> AD s (Forward Double))

instance Traversable f => IsProblem (UsingForward f) where
type Domain (UsingForward f) = f Double

dim _ = length

toVector _ = VG.fromList . toList

writeToMVector _ = writeToMVector'

updateFromVector _ = updateFromVector'

func (UsingForward f) x = fst $ Forward.grad' f x

bounds (UsingForward _f) = Nothing

constraints (UsingForward _f) = []

instance Traversable f => HasGrad (UsingForward f) where
grad (UsingForward f) x = Forward.grad f x

grad' (UsingForward f) x = Forward.grad' f x

grad'M (UsingForward f) x gvec =
case Forward.grad' f x of
(y, g) -> do
writeToMVector' g gvec
return y

instance Traversable f => Optionally (HasGrad (UsingForward f)) where
optionalDict = hasOptionalDict

instance Optionally (HasHessian (UsingForward f)) where
optionalDict = Nothing

-- ------------------------------------------------------------------------

-- | Type for defining a function and its gradient using automatic differentiation
-- provided by "Numeric.AD.Mode.Kahn".
--
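-- A usage sketch, assuming the 'rosenbrock' function from the module-header
-- example. Kahn mode is an alternative reverse-mode implementation to
-- "Numeric.AD.Mode.Reverse":
--
-- > result <- minimize LBFGS def (UsingKahn rosenbrock) [-3,-4]
--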
-- @since 0.2.0.0
data UsingKahn f
= UsingKahn (forall s. f (AD s (Kahn Double)) -> AD s (Kahn Double))

instance Traversable f => IsProblem (UsingKahn f) where
type Domain (UsingKahn f) = f Double

dim _ = length

toVector _ = VG.fromList . toList

writeToMVector _ = writeToMVector'

updateFromVector _ = updateFromVector'

func (UsingKahn f) x = fst $ Kahn.grad' f x

bounds (UsingKahn _f) = Nothing

constraints (UsingKahn _f) = []

instance Traversable f => HasGrad (UsingKahn f) where
grad (UsingKahn f) x = Kahn.grad f x

grad' (UsingKahn f) x = Kahn.grad' f x

grad'M (UsingKahn f) x gvec =
case Kahn.grad' f x of
(y, g) -> do
writeToMVector' g gvec
return y

instance Traversable f => Optionally (HasGrad (UsingKahn f)) where
optionalDict = hasOptionalDict

instance Optionally (HasHessian (UsingKahn f)) where
optionalDict = Nothing

-- ------------------------------------------------------------------------

-- | Type for defining a function and its gradient using automatic differentiation
-- provided by "Numeric.AD.Mode.Reverse". This is the type used in the
-- module-header example.
--
-- @since 0.2.0.0
data UsingReverse f
@@ -118,27 +276,7 @@ instance Optionally (HasHessian (UsingReverse f)) where
-- | Type for defining a function and its gradient and Hessian using automatic
-- differentiation provided by "Numeric.AD.Mode.Sparse".
--
-- It can be used with methods that require a Hessian (e.g. 'Newton').
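--
-- A usage sketch with 'Newton', assuming the 'rosenbrock' function from the
-- module-header example:
--
-- > result <- minimize Newton def (UsingSparse rosenbrock) [-3,-4]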
--
-- @since 0.2.0.0
data UsingSparse f
