Skip to content

Commit

Permalink
Complete the ADVal instance for mixed tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 9, 2025
1 parent aeccd7d commit 1122d9b
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 46 deletions.
178 changes: 171 additions & 7 deletions src/HordeAd/Core/Delta.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,17 @@ import Type.Reflection (typeRep)

import Data.Array.Mixed.Permutation (permInverse)
import Data.Array.Mixed.Permutation qualified as Permutation
import Data.Array.Mixed.Shape (shxAppend, shxDropSSX, shxTakeSSX, withKnownShX)
import Data.Array.Mixed.Shape
( IShX
, StaticShX
, shxAppend
, shxCast'
, shxDropSSX
, shxTail
, shxTakeSSX
, ssxFromShape
, withKnownShX
)
import Data.Array.Mixed.Types (unsafeCoerceRefl)
import Data.Array.Nested
( IShR
Expand All @@ -81,6 +91,7 @@ import Data.Array.Nested
, Replicate
, ShR (..)
, ShS (..)
, ShX (..)
, type (++)
)
import Data.Array.Nested qualified as Nested
Expand Down Expand Up @@ -561,8 +572,7 @@ data Delta :: Target -> TensorKindType -> Type where
-> Delta target (TKS2 (Permutation.PermutePrefix perm sh) r)
-- ^ Transpose according to the permutation.
ReshapeS :: ( TensorKind r, KnownShS sh, KnownShS sh2
, Nested.Product sh
~ Nested.Product sh2 )
, Nested.Product sh ~ Nested.Product sh2 )
=> Delta target (TKS2 sh r)
-> Delta target (TKS2 sh2 r)
-- ^ Change the shape of the tensor from the first to the second.
Expand Down Expand Up @@ -595,6 +605,48 @@ data Delta :: Target -> TensorKindType -> Type where
=> Delta target (TKX2 (sh1 ++ sh2) r)
-> IxXOf target sh1
-> Delta target (TKX2 sh2 r)
Sum0X :: (TensorKind r, KnownShX sh)
=> Delta target (TKX2 sh r) -> Delta target (TKX2 '[] r)
Dot0X :: (GoodScalar r, KnownShX sh)
=> target (TKX sh r) -> Delta target (TKX sh r)
-> Delta target (TKX '[] r)
ScatterX :: forall target r shm shn shp.
( TensorKind r, KnownShX shm, KnownShX shn, KnownShX shp
, KnownShX (shm ++ shn) ) -- needed for the Show instance
=> IShX (shp ++ shn) -> Delta target (TKX2 (shm ++ shn) r)
-> (IxXOf target shm -> IxXOf target shp)
-> Delta target (TKX2 (shp ++ shn) r)
AppendX :: forall target r sh.
(TensorKind r, KnownShX sh)
=> Delta target (TKX2 (Nothing ': sh) r)
-> Delta target (TKX2 (Nothing ': sh) r)
-> Delta target (TKX2 (Nothing ': sh) r)
SliceX :: forall target i n k r sh.
(TensorKind r, KnownNat i, KnownNat n, KnownNat k, KnownShX sh)
=> Delta target (TKX2 (Just (i + n + k) ': sh) r)
-> Delta target (TKX2 (Just n ': sh) r)
ReverseX :: (TensorKind r, KnownShX sh)
=> Delta target (TKX2 (mn ': sh) r)
-> Delta target (TKX2 (mn ': sh) r)
TransposeX :: forall perm sh r target.
(TensorKind r, PermC perm, KnownShX sh, Rank perm <= Rank sh)
=> Permutation.Perm perm
-> Delta target (TKX2 sh r)
-> Delta target (TKX2 (Permutation.PermutePrefix perm sh) r)
ReshapeX :: (TensorKind r, KnownShX sh, KnownShX sh2)
=> IShX sh2 -> Delta target (TKX2 sh r)
-> Delta target (TKX2 sh2 r)
MCastX :: (TensorKind x, KnownShX sh)
=> StaticShX sh2 -> Delta target (TKX2 sh x)
-> Delta target (TKX2 sh2 x)
GatherX :: forall target r shm shn shp.
( TensorKind r, KnownShX shm, KnownShX shn, KnownShX shp
, KnownShX (shp ++ shn) ) -- needed for the Show instance
=> IShX (shm ++ shn) -> Delta target (TKX2 (shp ++ shn) r)
-> (IxXOf target shm -> IxXOf target shp)
-> Delta target (TKX2 (shm ++ shn) r)
CastX :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2, KnownShX sh)
=> Delta target (TKX sh r1) -> Delta target (TKX sh r2)
ZipX :: (TensorKind y, TensorKind z, KnownShX sh)
=> Delta target (TKProduct (TKX2 sh y) (TKX2 sh z))
-> Delta target (TKX2 sh (TKProduct y z))
Expand Down Expand Up @@ -747,7 +799,7 @@ shapeDeltaFull = \case
FTKS _ x -> FTKS knownShS x
Sum0S d -> case shapeDeltaFull d of
FTKS _ x -> FTKS ZSS x
Dot0S{} -> FTKS knownShS FTKScalar
Dot0S{} -> FTKS ZSS FTKScalar
ScatterS @_ @_ @_ @shn @shp d _ -> case shapeDeltaFull d of
FTKS _ x -> FTKS (knownShS @shp `shsAppend` knownShS @shn) x
AppendS a _ -> case shapeDeltaFull a of
Expand All @@ -756,9 +808,7 @@ shapeDeltaFull = \case
FTKS _ x -> FTKS knownShS x
ReverseS d -> shapeDeltaFull d
TransposeS @_ @sh2 perm d -> case shapeDeltaFull d of
FTKS _ x ->
withKnownShS (shsPermutePrefix perm (knownShS @sh2)) $
FTKS knownShS x
FTKS _ x -> FTKS (shsPermutePrefix perm (knownShS @sh2)) x
ReshapeS d -> case shapeDeltaFull d of
FTKS _ x -> FTKS knownShS x
GatherS @_ @_ @shm @shn d _ -> case shapeDeltaFull d of
Expand All @@ -772,6 +822,25 @@ shapeDeltaFull = \case

IndexX @sh1 d _ix -> case shapeDeltaFull d of
FTKX sh x -> FTKX (shxDropSSX sh (knownShX @sh1)) x
Sum0X d -> case shapeDeltaFull d of
FTKX _ x -> FTKX ZSX x
Dot0X{} -> FTKX ZSX FTKScalar
ScatterX sh d _ -> case shapeDeltaFull d of
FTKX _ x -> FTKX sh x
AppendX a _ -> shapeDeltaFull a
SliceX @_ @_ @n d -> case shapeDeltaFull d of
FTKX sh x -> FTKX (Nested.SKnown (SNat @n) :$% shxTail sh) x
ReverseX d -> shapeDeltaFull d
TransposeX perm d -> case shapeDeltaFull d of
FTKX sh x -> FTKX (shxPermutePrefix perm sh) x
ReshapeX sh2 d -> case shapeDeltaFull d of
FTKX _ x -> FTKX sh2 x
MCastX sh2 d -> case shapeDeltaFull d of
FTKX sh x -> FTKX (shxCast' sh sh2) x
GatherX sh d _ -> case shapeDeltaFull d of
FTKX _ x -> FTKX sh x
CastX d -> case shapeDeltaFull d of
FTKX sh FTKScalar -> FTKX sh FTKScalar
ZipX d -> case shapeDeltaFull d of
FTKProduct (FTKX sh y) (FTKX _ z) -> FTKX sh (FTKProduct y z)
UnzipX d -> case shapeDeltaFull d of
Expand Down Expand Up @@ -832,6 +901,12 @@ lengthDelta d = case shapeDelta d of
ZSR -> error "lengthDelta: impossible pattern needlessly required"
k :$: _ -> k

-- | Shape of a mixed-tensor (@TKX2@) delta expression.
--
-- Computes the full tensor kind of the delta with 'shapeDeltaFull' and
-- returns only its shape component; the full kind of the elements is
-- discarded.
shapeDeltaX :: forall target r sh.
               (TensorKind r, KnownShX sh)
            => Delta target (TKX2 sh r) -> IShX sh
shapeDeltaX delta
  | FTKX shape _ <- shapeDeltaFull delta = shape

shapeDeltaH :: forall target.
Delta target TKUntyped -> VoidHVector
shapeDeltaH t = case shapeDeltaFull t of
Expand Down Expand Up @@ -1048,6 +1123,19 @@ evalSRuntimeSpecialized !s !c =
Just Refl -> evalRevSame @(TKS sh Float) s c
_ -> const s

-- | Reverse-mode evaluation of a mixed-tensor delta over scalars,
-- dispatched at runtime on the scalar type @r@.
--
-- The scalar type is compared, via its 'typeRep', first against 'Double'
-- and then against 'Float'. On a match, evaluation is delegated to
-- 'evalRevSame' instantiated at that concrete scalar type. For any other
-- scalar type the delta argument is ignored and the evaluation state is
-- returned unchanged.
--
-- Both the state and the cotangent are forced (bang patterns) before the
-- dispatch takes place.
evalXRuntimeSpecialized
  :: forall sh r target.
     (GoodScalar r, KnownShX sh, ADReadyNoLet target, ShareTensor target)
  => EvalState target -> target (ADTensorKind (TKX sh r))
  -> Delta target (TKX sh r)
  -> EvalState target
evalXRuntimeSpecialized !s !c
  | Just Refl <- testEquality (typeRep @r) (typeRep @Double) =
      evalRevSame @(TKX sh Double) s c
  | Just Refl <- testEquality (typeRep @r) (typeRep @Float) =
      evalRevSame @(TKX sh Float) s c
  | otherwise =
      -- Neither Double nor Float: the delta is dropped, state is kept as is.
      const s

evalRev
:: forall y target.
(TensorKind y, ADReadyNoLet target, ShareTensor target)
Expand Down Expand Up @@ -1325,6 +1413,53 @@ evalRevSame !s !c = \case
evalRevSame s (dmkHVector $ cs V.// [(i, ci)]) d

IndexX{} -> error "TODO"
Sum0X d ->
evalRevSame s (xreplicate0N (shapeDeltaX d) c) d
Dot0X v vd ->
evalRevSame s (v * xreplicate0N (xshape v) c) vd
-- too slow: evalRevSame s (smap0N (* (sscalar c)) v) vd
ScatterX @_ @_ @shm @shn @shp _sh d f ->
evalRevSame s (xgather @_ @_ @shm @shn @shp (shapeDeltaX d) c f) d
AppendX d e -> case (shapeDeltaX d, shapeDeltaX e) of
(shd@(Nested.SUnknown m :$% rest), she@(Nested.SUnknown n :$% _)) ->
withSNat m $ \(SNat @m) -> withSNat n $ \(SNat @n) ->
let cShared =
tshare $ xmcast (ssxFromShape
$ Nested.SKnown (SNat @(m + n)) :$% rest) c
s2 = evalRevSame s (xmcast (ssxFromShape shd)
$ xslice (Proxy @0) (Proxy @m) cShared) d
in evalRevSame s2 (xmcast (ssxFromShape she)
$ xslice (Proxy @m) (Proxy @n) cShared) e
SliceX @_ @i @n @k d -> case tftk (stensorKind @y) c of
FTKX (_ :$% rest) x ->
evalRevSame s
(xmcast (ssxFromShape $ Nested.SKnown (SNat @(i + n + k)) :$% rest)
$ xconcat
[ constantTarget 0 (FTKX (Nested.SUnknown (valueOf @i) :$% rest) x)
, xmcast (ssxFromShape $ Nested.SUnknown (valueOf @n) :$% rest) c
, constantTarget 0 (FTKX (Nested.SUnknown (valueOf @k) :$% rest) x) ])
d
ReverseX d -> evalRevSame s (xreverse c) d
TransposeX @perm @sh2 perm d ->
withKnownShX (ssxPermutePrefix perm (knownShX @sh2)) $
permInverse perm $ \(permRev :: Permutation.Perm permR) _ ->
gcastWith (unsafeCoerceRefl
:: Permutation.PermutePrefix permR (Permutation.PermutePrefix perm sh2) :~: sh2)
$ gcastWith (unsafeCoerceRefl
:: Rank (Permutation.PermutePrefix perm sh2) :~: Rank sh2)
$ gcastWith (unsafeCoerceRefl
:: Rank permR :~: Rank perm)
$ evalRevSame s (xtranspose permRev c) d
ReshapeX _sh d ->
evalRevSame s (xreshape (shapeDeltaX d) c) d
MCastX @_ @sh sh2 d ->
withKnownShX sh2 $
evalRevSame s (xmcast (knownShX @sh) c) d
GatherX @_ @_ @shm @shn @shp _sh d f ->
evalRevSame s (xscatter @_ @_ @shm @shn @shp (shapeDeltaX d) c f) d
CastX @r1 @_ @sh d ->
evalXRuntimeSpecialized s (toADTensorKindShared (stensorKind @(TKX sh r1))
$ xcast c) d
ZipX d ->
evalRevSame s (xunzip c) d
UnzipX d ->
Expand Down Expand Up @@ -1422,6 +1557,10 @@ evalRevFromnMap s@EvalState{nMap, dMap} =
withKnownShS sh $ case DMap.lookup n dMap of
Just (RepAD c) -> evalSRuntimeSpecialized @sh @r s2 c d
Nothing -> errorMissing
STKX @sh sh (STKScalar @r _) ->
withKnownShX sh $ case DMap.lookup n dMap of
Just (RepAD c) -> evalXRuntimeSpecialized @sh @r s2 c d
Nothing -> errorMissing
_ -> case DMap.lookup n dMap of
Just (RepAD c) -> evalRev s2 c d
Nothing -> errorMissing
Expand Down Expand Up @@ -1697,6 +1836,31 @@ evalFwdSame params s = \case
-- in (s2, sfromD $ tunvector v V.! i)

IndexX d ix -> second (`xindex` ix) $ evalFwdSame params s d
Sum0X (ZeroG (FTKX _ x)) -> (s, constantTarget 0 (FTKX ZSX x))
Sum0X d -> second xsum0 $ evalFwdSame params s d
Dot0X _ ZeroG{} -> (s, xrepl ZSX 0)
Dot0X v d -> second (xdot0 v) $ evalFwdSame params s d
ScatterX @_ @_ @shm @shn @shp sh d f ->
let (s2, t) = evalFwdSame params s d
in (s2, xscatter @_ @_ @shm @shn @shp sh t f)
AppendX d e ->
let (s2, t) = evalFwdSame params s d
(s3, u) = evalFwdSame params s2 e
in (s3, xappend t u)
SliceX @_ @i d -> second (xslice (Proxy @i) Proxy) $ evalFwdSame params s d
ReverseX d -> second xreverse $ evalFwdSame params s d
TransposeX perm d -> second (xtranspose perm)
$ evalFwdSame params s d
ReshapeX sh2 d -> second (xreshape sh2) $ evalFwdSame params s d
MCastX sh2 d -> second (xmcast sh2) $ evalFwdSame params s d
GatherX @_ @_ @shm @shn @shp sh d f ->
let (s2, t) = evalFwdSame params s d
in (s2, xgather @_ @_ @shm @shn @shp sh t f)
d0@(CastX @r1 @_ @sh d)
| Dict <- lemTensorKindOfAD (stensorKind @(TKX sh r1)) ->
case sameTensorKind @(TKX sh r1) @(ADTensorKind (TKX sh r1)) of
Just Refl -> second xcast $ evalFwdSame params s d
_ -> (s, constantTarget 0 $ aDFTK $ shapeDeltaFull d0)
ZipX d -> second xzip $ evalFwdSame params s d
UnzipX d -> second xunzip $ evalFwdSame params s d

Expand Down
73 changes: 70 additions & 3 deletions src/HordeAd/Core/OpsADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NonEmpty
import Data.Maybe (fromMaybe)
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality ((:~:) (Refl))
import Data.Type.Equality (testEquality, (:~:) (Refl))
import Data.Vector.Generic qualified as V
import GHC.TypeLits (fromSNat, KnownNat, sameNat)
import Type.Reflection (typeRep)
Expand All @@ -30,6 +30,9 @@ import Data.Array.Mixed.Shape (ssxAppend, withKnownShX, ssxFromShape, ssxReplica
import Data.Array.Nested
( IxR (..)
, IxS (..)
, IxX (..)
, StaticShX(..)
, ShX (..)
, KnownShS (..)
, KnownShX (..)
, Rank
Expand Down Expand Up @@ -409,16 +412,79 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
sD t d = dD t d
sScale k = ScaleG k

xminIndex (D u _) =
let v = xminIndex u
in fromPrimalADVal v
xmaxIndex (D u _) =
let v = xmaxIndex u
in fromPrimalADVal v
xfloor (D u _) =
let v = xfloor u
in fromPrimalADVal v

xshape (D u _) = xshape u
xiota = fromPrimalADVal xiota
xindex (D u u') i =
let ix = tprimalPart (STKScalar typeRep) <$> i
in dD (xindex u ix) (IndexX u' ix)
xsum (D u u') = dD (xsum u) (SumG SNat stensorKind u')
xsum0 (D u u') = dD (xsum0 u) (Sum0X u')
xdot0 (D ue u') (D ve v') =
    -- The bangs below are necessary for GHC 9.2.7 test results to match 9.4.
let !u = tshare ue in
let !v = tshare ve
in dD (xdot0 u v) (AddG (Dot0X v u') (Dot0X u v'))
xscatter @r @shm @shn @shp sh (D u u') f =
withKnownShX (knownShX @shm `ssxAppend` knownShX @shn) $
withKnownShX (knownShX @shp `ssxAppend` knownShX @shn) $
let g x = tprimalPart (STKScalar typeRep)
<$> f (tfromPrimal (STKScalar typeRep) <$> x)
in dD (xscatter @_ @r @shm @shn @shp sh u g)
(ScatterX @_ @r @shm @shn @shp sh u' g)

xfromVector @_ @k lu = assert (length lu == valueOf @k) $ -- TODO: Move these assertions to the base instances, that is concrete and AST
dD (xfromVector $ V.map (\(D u _) -> u) lu)
(FromVectorG (SNat @k) stensorKind $ V.map (\(D _ u') -> u') lu)
-- xreplicate (D u (DeltaX u')) = dD (xreplicate u) (DeltaX $ ReplicateX u')
xreplicate _ = error "TODO"
xunravelToList (D u u') =
let lu = xunravelToList u
f i ui = dD ui (IndexX u' (fromIntegral i :.% ZIX))
in imap f lu
xreplicate (D u u') = dD (xreplicate u) (ReplicateG SNat stensorKind u')
xappend (D u u') (D v v') =
dD (xappend u v) (AppendX u' v')
xslice @_ @i i_proxy n_proxy (D u u') =
dD (xslice i_proxy n_proxy u) (SliceX @target @i u')
xreverse (D u u') = withKnownShX (ssxFromShape $ xshape u) $
dD (xreverse u) (ReverseX u')

xtranspose @_ @_ @sh perm (D u u') =
withKnownShX (ssxPermutePrefix perm (knownShX @sh)) $
dD (xtranspose perm u) (TransposeX @_ @_ @_ @target perm u')
xreshape @_ @sh sh t@(D u u') =
case testEquality (knownShX @sh) (ssxFromShape sh) of
Just Refl | sh == xshape u -> t
_ -> dD (xreshape sh u) (ReshapeX sh u')
xbuild1 @r @n @sh f = case sameNat (Proxy @n) (Proxy @0) of
Just Refl -> case stensorKind @r of
STKScalar{} -> case knownShX @sh of
ZKX -> xconcrete $ Nested.memptyArray ZSX
_ -> error "xbuild1: empty nested array"
_ -> error "xbuild1: shape ambiguity"
Nothing -> xfromList $ NonEmpty.map (f . fromInteger)
$ 0 :| [1 .. valueOf @n - 1]
-- element-wise (POPL) version
xmcast sh2 (D u u') = withKnownShX sh2 $
dD (xmcast sh2 u) (MCastX sh2 u')
xgather @r @shm @shn @shp sh (D u u') f =
withKnownShX (ssxFromShape sh) $
withKnownShX (knownShX @shp `ssxAppend` knownShX @shn) $
let g x = tprimalPart (STKScalar typeRep)
<$> f (tfromPrimal (STKScalar typeRep) <$> x)
in dD (xgather @_ @r @shm @shn @shp sh u g) (GatherX @_ @r @shm @shn @shp sh u' g)
xcast (D u u') = dD (xcast u) (CastX u')
xfromIntegral (D u _) =
let v = xfromIntegral u
in fromPrimalADVal v
xzip (D u u') = dD (xzip u) (ZipX u')
xunzip (D u u') = dD (xunzip u) (UnzipX u')
xtoScalar (D t d) = dDnotShared (xtoScalar t) (ToScalarG $ SFromX d)
Expand All @@ -427,6 +493,7 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
xprimalPart (D u _) = u
xdualPart (D _ u') = u'
xD t d = dD t d
xScale k = ScaleG k

kfloor (D u _) =
let v = kfloor u
Expand Down
Loading

0 comments on commit 1122d9b

Please sign in to comment.