Skip to content

Commit

Permalink
Merge pull request #24 from msakai/feature/problem-proxy
Browse files Browse the repository at this point in the history
Pass only proxy for some methods of IsProblem
  • Loading branch information
msakai authored Jun 25, 2024
2 parents cdc49ea + 5f88828 commit 05243be
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 39 deletions.
77 changes: 39 additions & 38 deletions numeric-optimization/src/Numeric/Optimization.hs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ import Data.Default.Class
import Data.Functor.Contravariant
import Data.IORef
import Data.Maybe
import Data.Proxy
import Data.Vector.Storable (Vector)
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
Expand Down Expand Up @@ -377,13 +378,13 @@ class IsProblem prob where
-- | Dimention of a @'Domain' prob@ value.
--
-- @since 0.2.0.0
dim :: prob -> Domain prob -> Int
dim :: Proxy prob -> Domain prob -> Int
dim prob x = VS.length $ toVector prob x

-- | Convert a @'Domain' prob@ value to a storable 'Vector'.
--
-- @since 0.2.0.0
toVector :: prob -> Domain prob -> Vector Double
toVector :: Proxy prob -> Domain prob -> Vector Double
toVector prob x = VS.create $ do
vec <- VSM.new (dim prob x)
writeToMVector prob x vec
Expand All @@ -394,15 +395,15 @@ class IsProblem prob where
-- It can be thought as a variant of 'toVector' in destination-passing style.
--
-- @since 0.2.0.0
writeToMVector :: PrimMonad m => prob -> Domain prob -> VSM.MVector (PrimState m) Double -> m ()
writeToMVector :: PrimMonad m => Proxy prob -> Domain prob -> VSM.MVector (PrimState m) Double -> m ()
writeToMVector prob x ret = VG.imapM_ (VGM.write ret) (toVector prob x)

-- | Convert a storable 'Vector' back to a value of @'Domain' prob@
--
-- The @'Domain' prob@ argument is used as the return value's /shape/.
--
-- @since 0.2.0.0
updateFromVector :: prob -> Domain prob -> Vector Double -> Domain prob
updateFromVector :: Proxy prob -> Domain prob -> Vector Double -> Domain prob

-- | Objective function
--
Expand Down Expand Up @@ -432,16 +433,16 @@ class IsProblem prob => HasGrad prob where
-- | Pair of 'func' and 'grad'
grad' :: prob -> Domain prob -> (Double, Domain prob)
grad' prob x = runST $ do
gret <- VGM.new (dim prob x)
gret <- VGM.new (dim (Proxy :: Proxy prob) x)
y <- grad'M prob x gret
g <- VG.unsafeFreeze gret
return (y, updateFromVector prob x g)
return (y, updateFromVector (Proxy :: Proxy prob) x g)

-- | Similar to 'grad'' but destination passing style is used for gradient vector
grad'M :: PrimMonad m => prob -> Domain prob -> VSM.MVector (PrimState m) Double -> m Double
grad'M prob x gvec = do
let y = func prob x
writeToMVector prob (grad prob x) gvec
writeToMVector (Proxy :: Proxy prob) (grad prob x) gvec
return y

{-# MINIMAL grad | grad' | grad'M #-}
Expand All @@ -460,7 +461,7 @@ class IsProblem prob => HasHessian prob where
--
-- See also <https://hackage.haskell.org/package/ad-4.5.4/docs/Numeric-AD.html#v:hessianProduct>.
hessianProduct :: prob -> Domain prob -> Domain prob -> Domain prob
hessianProduct prob x v = updateFromVector prob x $ hessian prob x LA.#> toVector prob v
hessianProduct prob x v = updateFromVector (Proxy :: Proxy prob) x $ hessian prob x LA.#> toVector (Proxy :: Proxy prob) v

{-# MINIMAL hessian #-}

Expand All @@ -481,15 +482,15 @@ hasOptionalDict = Just Dict
data Constraint

-- | Bounds for unconstrained problems, i.e. (-∞,+∞).
boundsUnconstrained :: IsProblem prob => prob -> Domain prob -> (Domain prob, Domain prob)
boundsUnconstrained :: forall prob. IsProblem prob => Proxy prob -> Domain prob -> (Domain prob, Domain prob)
boundsUnconstrained prob x = (lb, ub)
where
v = toVector prob x
lb = updateFromVector prob x $ VG.map (\_ -> -infinity) v
ub = updateFromVector prob x $ VG.map (\_ -> infinity) v

-- | Whether all lower bounds are -∞ and all upper bounds are +∞.
isUnconstainedBounds :: IsProblem prob => prob -> (Domain prob, Domain prob) -> Bool
isUnconstainedBounds :: forall prob. IsProblem prob => Proxy prob -> (Domain prob, Domain prob) -> Bool
isUnconstainedBounds prob (lb, ub) =
VG.all (\b -> isInfinite b && b < 0) (toVector prob lb) &&
VG.all (\b -> isInfinite b && b > 0) (toVector prob ub)
Expand Down Expand Up @@ -530,9 +531,9 @@ minimize
-> Domain prob -- ^ Initial value
-> IO (Result (Domain prob))
minimize method params prob x0 = do
let x0' = toVector prob x0
ret <- minimizeV method (contramap (updateFromVector prob x0) params) (AsVectorProblem prob x0) x0'
return $ fmap (updateFromVector prob x0) ret
let x0' = toVector (Proxy :: Proxy prob) x0
ret <- minimizeV method (contramap (updateFromVector (Proxy :: Proxy prob) x0) params) (AsVectorProblem prob x0) x0'
return $ fmap (updateFromVector (Proxy :: Proxy prob) x0) ret

minimizeV
:: forall prob. (IsProblem prob, Optionally (HasGrad prob), Optionally (HasHessian prob))
Expand Down Expand Up @@ -930,10 +931,10 @@ data WithGrad prob = WithGrad prob (Domain prob -> Domain prob)

instance IsProblem prob => IsProblem (WithGrad prob) where
type Domain (WithGrad prob) = Domain prob
dim (WithGrad prob _g) = dim prob
updateFromVector (WithGrad prob _g) x0 = updateFromVector prob x0
toVector (WithGrad prob _g) = toVector prob
writeToMVector (WithGrad prob _g) = writeToMVector prob
dim _ = dim (Proxy :: Proxy prob)
updateFromVector _ x0 = updateFromVector (Proxy :: Proxy prob) x0
toVector _ = toVector (Proxy :: Proxy prob)
writeToMVector _ = writeToMVector (Proxy :: Proxy prob)

func (WithGrad prob _g) = func prob
bounds (WithGrad prob _g) = bounds prob
Expand Down Expand Up @@ -962,10 +963,10 @@ data WithHessian prob = WithHessian prob (Domain prob -> Matrix Double)

instance IsProblem prob => IsProblem (WithHessian prob) where
type Domain (WithHessian prob) = Domain prob
dim (WithHessian prob _hess) = dim prob
updateFromVector (WithHessian prob _hess) x0 = updateFromVector prob x0
toVector (WithHessian prob _hess) = toVector prob
writeToMVector (WithHessian prob _g) = writeToMVector prob
dim _ = dim (Proxy :: Proxy prob)
updateFromVector _ x0 = updateFromVector (Proxy :: Proxy prob) x0
toVector _ = toVector (Proxy :: Proxy prob)
writeToMVector _ = writeToMVector (Proxy :: Proxy prob)

func (WithHessian prob _hess) = func prob
bounds (WithHessian prob _hess) = bounds prob
Expand Down Expand Up @@ -993,10 +994,10 @@ data WithBounds prob = WithBounds prob (Domain prob, Domain prob)

instance IsProblem prob => IsProblem (WithBounds prob) where
type Domain (WithBounds prob) = Domain prob
dim (WithBounds prob _bounds) = dim prob
updateFromVector (WithBounds prob _bounds) x0 = updateFromVector prob x0
toVector (WithBounds prob _bounds) = toVector prob
writeToMVector (WithBounds prob _g) = writeToMVector prob
dim _ = dim (Proxy :: Proxy prob)
updateFromVector _ x0 = updateFromVector (Proxy :: Proxy prob) x0
toVector _ = toVector (Proxy :: Proxy prob)
writeToMVector _ = writeToMVector (Proxy :: Proxy prob)

func (WithBounds prob _bounds) = func prob
bounds (WithBounds _prob bounds) = Just bounds
Expand Down Expand Up @@ -1030,10 +1031,10 @@ data WithConstraints prob = WithConstraints prob [Constraint]

instance IsProblem prob => IsProblem (WithConstraints prob) where
type Domain (WithConstraints prob) = Domain prob
dim (WithConstraints prob _constraints) = dim prob
updateFromVector (WithConstraints prob _constraints) x0 = updateFromVector prob x0
toVector (WithConstraints prob _constraints) = toVector prob
writeToMVector (WithConstraints prob _g) = writeToMVector prob
dim _ = dim (Proxy :: Proxy prob)
updateFromVector _ x0 = updateFromVector (Proxy :: Proxy prob) x0
toVector _ = toVector (Proxy :: Proxy prob)
writeToMVector _ = writeToMVector (Proxy :: Proxy prob)

func (WithConstraints prob _constraints) = func prob
bounds (WithConstraints prob _constraints) = bounds prob
Expand Down Expand Up @@ -1066,27 +1067,27 @@ data AsVectorProblem prob = AsVectorProblem prob (Domain prob)

instance IsProblem prob => IsProblem (AsVectorProblem prob) where
type Domain (AsVectorProblem prob) = Vector Double
dim (AsVectorProblem prob x0) _ = dim prob x0
dim _ = VS.length
updateFromVector _ _ = id
toVector _ = id
-- default implementation of 'writeToMVector' is what we want

func (AsVectorProblem prob x0) = func prob . updateFromVector prob x0
func (AsVectorProblem prob x0) = func prob . updateFromVector (Proxy :: Proxy prob) x0
bounds (AsVectorProblem prob _x0) =
case bounds prob of
Nothing -> Nothing
Just (lb, ub) -> Just (toVector prob lb, toVector prob ub)
Just (lb, ub) -> Just (toVector (Proxy :: Proxy prob) lb, toVector (Proxy :: Proxy prob) ub)
constraints (AsVectorProblem prob _x0) = constraints prob

instance HasGrad prob => HasGrad (AsVectorProblem prob) where
grad (AsVectorProblem prob x0) x = toVector prob $ grad prob (updateFromVector prob x0 x)
grad (AsVectorProblem prob x0) x = toVector (Proxy :: Proxy prob) $ grad prob (updateFromVector (Proxy :: Proxy prob) x0 x)
grad' (AsVectorProblem prob x0) x =
case grad' prob (updateFromVector prob x0 x) of
(y, g) -> (y, toVector prob g)
grad'M (AsVectorProblem prob x0) x ret = grad'M prob (updateFromVector prob x0 x) ret
case grad' prob (updateFromVector (Proxy :: Proxy prob) x0 x) of
(y, g) -> (y, toVector (Proxy :: Proxy prob) g)
grad'M (AsVectorProblem prob x0) x ret = grad'M prob (updateFromVector (Proxy :: Proxy prob) x0 x) ret

instance HasHessian prob => HasHessian (AsVectorProblem prob) where
hessian (AsVectorProblem prob x0) x = hessian prob (updateFromVector prob x0 x)
hessianProduct (AsVectorProblem prob x0) x v = toVector prob $ hessianProduct prob (updateFromVector prob x0 x) (updateFromVector prob x0 v)
hessian (AsVectorProblem prob x0) x = hessian prob (updateFromVector (Proxy :: Proxy prob) x0 x)
hessianProduct (AsVectorProblem prob x0) x v = toVector (Proxy :: Proxy prob) $ hessianProduct prob (updateFromVector (Proxy :: Proxy prob) x0 x) (updateFromVector (Proxy :: Proxy prob) x0 v)

-- ------------------------------------------------------------------------
3 changes: 2 additions & 1 deletion numeric-optimization/test/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Test.Hspec
import Control.Exception
import Control.Monad
import Data.IORef
import Data.Proxy
import Numeric.LinearAlgebra (Matrix, (><))
import Numeric.Optimization
import AllClose
Expand Down Expand Up @@ -158,7 +159,7 @@ main = hspec $ do

context "when given rosenbrock function with bounds (-infinity, +infinity)" $
it "returns the global optimum" $ do
let prob = rosenbrock `WithGrad` rosenbrock' `WithBounds` boundsUnconstrained prob (0,0)
let prob = rosenbrock `WithGrad` rosenbrock' `WithBounds` boundsUnconstrained (Proxy :: Proxy ((Double, Double) -> Double)) (0,0)
result <- minimize LBFGSB def prob (-3,-4)
resultSuccess result `shouldBe` True
assertAllClose (def :: Tol Double) (resultSolution result) (1,1)
Expand Down

0 comments on commit 05243be

Please sign in to comment.