Skip to content

Commit

Permalink
Make BuildTensorKind(TKScalar) saner
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 11, 2025
1 parent 74291c8 commit 3c81602
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 67 deletions.
4 changes: 2 additions & 2 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,9 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
=> OpCodeNum2 -> AstTensor ms s (TKScalar r)
-> AstTensor ms s (TKScalar r)
-> AstTensor ms s (TKScalar r)
AstR1 :: (RealFloatF r, GoodScalar r)
AstR1 :: (RealFloatF r, Nested.FloatElt r, GoodScalar r)
=> OpCode1 -> AstTensor ms s (TKScalar r) -> AstTensor ms s (TKScalar r)
AstR2 :: (RealFloatF r, GoodScalar r)
AstR2 :: (RealFloatF r, Nested.FloatElt r, GoodScalar r)
=> OpCode2 -> AstTensor ms s (TKScalar r) -> AstTensor ms s (TKScalar r)
-> AstTensor ms s (TKScalar r)
AstI2 :: (IntegralF r, GoodScalar r)
Expand Down
11 changes: 9 additions & 2 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,9 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 @shm1 i1 rest1)
let ftk = FTKS knownShS x
in fromPrimal $ AstConcrete ftk (constantTarget def ftk)
Ast.AstReplicate _ (STKS sh _) v -> withKnownShS sh $ astIndex v rest1
Ast.AstReplicate _ STKScalar{} v | ZIS <- rest1 -> astFromScalar v
Ast.AstBuild1 @y2 _snat (var2, v) -> case stensorKind @y2 of
STKScalar{} | ZIS <- rest1 -> astFromScalar $ astLet var2 i1 v
STKS sh _ ->
withKnownShS sh $
withKnownShS (knownShS @shm1 `shsAppend` knownShS @shn) $
Expand Down Expand Up @@ -863,12 +865,13 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 @shm1 i1 rest1)
FTKS _ x ->
let ftk = FTKS knownShS x
in fromPrimal $ AstConcrete ftk (constantTarget def ftk)
Ast.AstFromVector{} | ZIS <- rest1 -> -- normal form
Ast.AstFromVector{} | ZIS <- rest1 -> -- normal form, STKScalar case included
Ast.AstIndexS v0 ix
Ast.AstFromVector @y2 snat l | STKS{} <- stensorKind @y2 ->
shareIx rest1 $ \ !ix2 ->
Ast.AstIndexS @'[in1] @shn (astFromVector snat $ V.map (`astIndexRec` ix2) l)
(i1 :.$ ZIS)
Ast.AstFromVector{} -> error "astIndexKnobsS: impossible case"
Ast.AstAppendS @m u v ->
let ulen = AstConcrete FTKScalar $ RepN $ valueOf @m
ix1 = i1 :.$ rest1
Expand Down Expand Up @@ -1498,6 +1501,8 @@ astGatherKnobsS knobs v0 (vars0, ix0) =
in fromPrimal $ AstConcrete ftk (constantTarget def ftk)
Ast.AstReplicate _ STKS{} v ->
astGather @shm' @shn' @shp1' v (vars4, rest4)
Ast.AstReplicate _ STKScalar{} v | ZIS <- rest4 ->
astGather @shm' @shn' @shp1' (astFromScalar v) (vars4, rest4)
Ast.AstBuild1{} -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4)
Ast.AstLet var u v ->
astLet var u (astGatherCase @shm' @shn' @shp' v (vars4, ix4))
Expand Down Expand Up @@ -1594,7 +1599,8 @@ astGatherKnobsS knobs v0 (vars0, ix0) =
FTKS _ x ->
let ftk = FTKS knownShS x
in fromPrimal $ AstConcrete ftk (constantTarget def ftk)
Ast.AstFromVector{} | gatherFromNFS vars4 ix4 -> -- normal form
Ast.AstFromVector{} | gatherFromNFS vars4 ix4 -> -- normal form,
-- STKScalar case included
Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4)
Ast.AstFromVector @y2 snat l | STKS{} <- stensorKind @y2 ->
-- Term rest4 is duplicated without sharing and we can't help it,
Expand All @@ -1611,6 +1617,7 @@ astGatherKnobsS knobs v0 (vars0, ix0) =
in astGather @shm' @shn' @(p1' ': shm')
(astFromVector snat $ V.map f l)
(varsFresh, i5 :.$ IxS ixFresh)
Ast.AstFromVector{} -> error "astGatherCase: impossible case"
Ast.AstAppendS @m u v ->
let ulen = AstConcrete FTKScalar $ RepN $ valueOf @m
iu = simplifyAstInt (AstN2 MinusOp i4 ulen)
Expand Down
50 changes: 17 additions & 33 deletions src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import Prelude

import Control.Exception.Assert.Sugar
import Control.Monad (when)
import Data.Default
import Data.Functor.Const
import Data.IntMap.Strict qualified as IM
import Data.IORef
Expand Down Expand Up @@ -42,7 +41,6 @@ import Data.Array.Nested
, ShS (..)
, type (++)
)
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Internal.Shape
( shCvtSX
, shsAppend
Expand Down Expand Up @@ -161,26 +159,13 @@ build1V snat@SNat (var, v0) =
traceRule | Dict <- lemTensorKindOfBuild snat (stensorKind @y) =
mkTraceRule "build1V" bv v0 1
in case v0 of
Ast.AstFromScalar v2@(Ast.AstVar _ var2) -- TODO: make compositional
| varNameToAstVarId var2 == varNameToAstVarId var -> traceRule $
case isTensorInt v2 of
Just Refl -> fromPrimal @s $ Ast.AstIotaS @k
-- results in smaller terms than AstSlice(AstIota), because
-- not turned into a concrete array so early
_ -> error "build1V: build variable is not an index variable"
Ast.AstFromScalar{} -> case astNonIndexStep v0 of
Ast.AstFromScalar{} -> -- let's hope this doesn't oscillate
error $ "build1V: AstFromScalar: building over scalars is undefined: "
++ show v0
v1 -> build1VOccurenceUnknown snat (var, v1) -- last ditch effort
Ast.AstToScalar{} ->
error $ "build1V: AstToScalar: building over scalars is undefined: "
++ show v0
Ast.AstFromScalar t -> build1V snat (var, t)
Ast.AstToScalar t -> build1V snat (var, t)
Ast.AstPair @x @z t1 t2
| Dict <- lemTensorKindOfBuild snat (stensorKind @x)
, Dict <- lemTensorKindOfBuild snat (stensorKind @z) -> traceRule $
astPair (build1VOccurenceUnknown snat (var, t1))
(build1VOccurenceUnknown snat (var, t2))
(build1VOccurenceUnknown snat (var, t2))
Ast.AstProject1 @_ @z t
| Dict <- lemTensorKindOfBuild snat (stensorKind @z)
, Dict <- lemTensorKindOfBuild snat (stensorKind @y) -> traceRule $
Expand All @@ -192,8 +177,7 @@ build1V snat@SNat (var, v0) =
Ast.AstVar _ var2 -> traceRule $
if varNameToAstVarId var2 == varNameToAstVarId var
then case isTensorInt v0 of
Just Refl -> Ast.AstToScalar $ Ast.AstConcrete (FTKS ZSS FTKScalar)
$ RepN $ Nested.sscalar def -- TODO: ???
Just Refl -> fromPrimal @s $ Ast.AstIotaS @k
_ -> error "build1V: build variable is not an index variable"
else error "build1V: AstVar can't contain other free variables"
Ast.AstPrimalPart v
Expand Down Expand Up @@ -247,27 +231,27 @@ build1V snat@SNat (var, v0) =
$ map (\v -> build1VOccurenceUnknown snat (var, v)) args

Ast.AstN1 opCode u -> traceRule $
Ast.AstN1 opCode (build1V snat (var, u))
Ast.AstN1S opCode (build1V snat (var, u))
Ast.AstN2 opCode u v -> traceRule $
Ast.AstN2 opCode (build1VOccurenceUnknown snat (var, u))
(build1VOccurenceUnknown snat (var, v))
Ast.AstN2S opCode (build1VOccurenceUnknown snat (var, u))
(build1VOccurenceUnknown snat (var, v))
-- we permit duplicated bindings, because they can't easily
-- be substituted into one another unlike. e.g., inside a let,
-- which may get inlined
Ast.AstR1 opCode u -> traceRule $
Ast.AstR1 opCode (build1V snat (var, u))
Ast.AstR1S opCode (build1V snat (var, u))
Ast.AstR2 opCode u v -> traceRule $
Ast.AstR2 opCode (build1VOccurenceUnknown snat (var, u))
(build1VOccurenceUnknown snat (var, v))
Ast.AstR2S opCode (build1VOccurenceUnknown snat (var, u))
(build1VOccurenceUnknown snat (var, v))
Ast.AstI2 opCode u v -> traceRule $
Ast.AstI2 opCode (build1VOccurenceUnknown snat (var, u))
(build1VOccurenceUnknown snat (var, v))
Ast.AstI2S opCode (build1VOccurenceUnknown snat (var, u))
(build1VOccurenceUnknown snat (var, v))
Ast.AstFloor v -> traceRule $
Ast.AstFloor $ build1V snat (var, v)
Ast.AstFloorS $ build1V snat (var, v)
Ast.AstCast v -> traceRule $
astCast $ build1V snat (var, v)
astCastS $ build1V snat (var, v)
Ast.AstFromIntegral v -> traceRule $
astFromIntegral $ build1V snat (var, v)
astFromIntegralS $ build1V snat (var, v)

Ast.AstN1R opCode u -> traceRule $
Ast.AstN1R opCode (build1V snat (var, u))
Expand Down Expand Up @@ -740,7 +724,7 @@ astTrBuild
-> AstTensor AstMethodLet s (BuildTensorKind k1 (BuildTensorKind k2 y))
-> AstTensor AstMethodLet s (BuildTensorKind k2 (BuildTensorKind k1 y))
astTrBuild stk t = case stk of
STKScalar{} -> t
STKScalar{} -> astTrS t
STKR SNat stk1 | Dict <- lemTensorKindOfSTK stk1 -> astTr t
STKS sh stk1 | Dict <- lemTensorKindOfSTK stk1 -> withKnownShS sh $ astTrS t
STKX sh stk1 | Dict <- lemTensorKindOfSTK stk1 -> withKnownShX sh $ astTrX t
Expand Down Expand Up @@ -769,7 +753,7 @@ astIndexBuild :: forall y k s. AstSpan s
-> AstInt AstMethodLet
-> AstTensor AstMethodLet s y
astIndexBuild snat@SNat stk u i = case stk of
STKScalar{} -> u
STKScalar{} -> Ast.AstToScalar $ astIndexStepS u (i :.$ ZIS)
STKR SNat x | Dict <- lemTensorKindOfSTK x ->
astIndexStep u (i :.: ZIR)
STKS sh x | Dict <- lemTensorKindOfSTK x ->
Expand Down
6 changes: 3 additions & 3 deletions src/HordeAd/Core/CarriersAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,13 @@ instance (GoodScalar r, IntegralF r, AstSpan s)
quotF = AstI2 QuotOp
remF = AstI2 RemOp

instance (GoodScalar r, RealFloatF r, AstSpan s)
instance (GoodScalar r, RealFloatF r, Nested.FloatElt r, AstSpan s)
=> Fractional (AstTensor ms s (TKScalar r)) where
u / v = AstR2 DivideOp u v
recip = AstR1 RecipOp
fromRational r = fromPrimal . AstConcrete FTKScalar . fromRational $ r

instance (GoodScalar r, RealFloatF r, AstSpan s)
instance (GoodScalar r, RealFloatF r, Nested.FloatElt r, AstSpan s)
=> Floating (AstTensor ms s (TKScalar r)) where
pi = error "pi not defined for ranked tensors"
exp = AstR1 ExpOp
Expand All @@ -128,7 +128,7 @@ instance (GoodScalar r, RealFloatF r, AstSpan s)
acosh = AstR1 AcoshOp
atanh = AstR1 AtanhOp

instance (GoodScalar r, RealFloatF r, AstSpan s)
instance (GoodScalar r, RealFloatF r, Nested.FloatElt r, AstSpan s)
=> RealFloatF (AstTensor ms s (TKScalar r)) where
atan2F = AstR2 Atan2Op

Expand Down
9 changes: 3 additions & 6 deletions src/HordeAd/Core/OpsADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -564,12 +564,9 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
dshape (D u _) = dshape u
tftk stk (D u _) = tftk stk u
-- Bangs are for the proper order of sharing stamps.
tcond !stk !b !u !v = case stk of
STKScalar{} -> -- TODO: eliminate this special case
rtoScalar $ ifF b (rfromScalar u) (rfromScalar v)
_ ->
let uv = tfromVector (SNat @2) stk (V.fromList [u, v])
in tindexBuildShare (SNat @2) stk uv (ifF b 0 1)
tcond !stk !b !u !v =
let uv = tfromVector (SNat @2) stk (V.fromList [u, v])
in tindexBuildShare (SNat @2) stk uv (ifF b 0 1)
tfromPrimal stk t | Dict <- lemTensorKindOfSTK stk = fromPrimalADVal t
tprimalPart _stk (D u _) = u
tdualPart _stk (D _ u') = u'
Expand Down
8 changes: 2 additions & 6 deletions src/HordeAd/Core/OpsConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import GHC.TypeLits
import Numeric.LinearAlgebra (Numeric)
import Numeric.LinearAlgebra qualified as LA
import System.Random
import Type.Reflection (typeRep)
import Unsafe.Coerce (unsafeCoerce)

import Data.Array.Mixed.Internal.Arith qualified as Mixed.Internal.Arith
Expand Down Expand Up @@ -646,8 +645,7 @@ ravel :: forall k y. TensorKind y
=> SNat k -> [RepN y]
-> RepN (BuildTensorKind k y)
ravel k@SNat t = case stensorKind @y of
STKScalar rep | Just Refl <- testEquality rep (typeRep @Z0) -> RepN Z0
STKScalar _ -> error "ravel: scalar"
STKScalar{} -> sfromList $ sfromScalar <$> NonEmpty.fromList t
STKR SNat x | Dict <- lemTensorKindOfSTK x ->
rfromList $ NonEmpty.fromList t
STKS sh x | Dict <- lemTensorKindOfSTK x ->
Expand All @@ -667,9 +665,7 @@ unravel :: forall k y. TensorKind y
=> SNat k -> RepN (BuildTensorKind k y)
-> [RepN y]
unravel k@SNat t = case stensorKind @y of
STKScalar rep | Just Refl <- testEquality rep (typeRep @Z0) ->
replicate (sNatValue k) (RepN Z0)
STKScalar _ -> error "unravel: scalar"
STKScalar{} -> map stoScalar $ sunravelToList t
STKR SNat x | Dict <- lemTensorKindOfSTK x ->
runravelToList t
STKS sh x | Dict <- lemTensorKindOfSTK x ->
Expand Down
23 changes: 11 additions & 12 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ module HordeAd.Core.TensorClass
import Prelude

import Data.Array.Mixed.Types (unsafeCoerceRefl)
import Data.Default
import Data.Kind (Constraint, Type)
import Data.List (transpose)
import Data.List.NonEmpty (NonEmpty)
Expand Down Expand Up @@ -135,7 +134,7 @@ class LetTensor (target :: Target) where
-> Data.Vector.Vector (target y)
-> target (BuildTensorKind k y)
tfromVector snat@SNat stk v = case stk of
STKScalar{} -> error "tfromVector: vector of scalars"
STKScalar{} -> sfromVector $ V.map sfromScalar v
STKR SNat x | Dict <- lemTensorKindOfSTK x ->
rfromVector v
STKS sh x | Dict <- lemTensorKindOfSTK x ->
Expand Down Expand Up @@ -164,7 +163,7 @@ class LetTensor (target :: Target) where
-> target (BuildTensorKind k z)
-> target z
tsum snat@SNat stk u = case stk of
STKScalar{} -> u
STKScalar{} -> stoScalar $ ssum u
STKR SNat x | Dict <- lemTensorKindOfSTK x ->
rsum u
STKS sh x | Dict <- lemTensorKindOfSTK x ->
Expand All @@ -185,7 +184,7 @@ class LetTensor (target :: Target) where
-> target z
-> target (BuildTensorKind k z)
treplicate snat@SNat stk u = case stk of
STKScalar{} -> u
STKScalar{} -> sreplicate $ sfromScalar u
STKR SNat x | Dict <- lemTensorKindOfSTK x -> rreplicate (sNatValue snat) u
STKS sh x | Dict <- lemTensorKindOfSTK x -> withKnownShS sh $ sreplicate u
STKX sh x | Dict <- lemTensorKindOfSTK x -> withKnownShX sh $ xreplicate u
Expand Down Expand Up @@ -214,7 +213,7 @@ class ShareTensor (target :: Target) where
-> Data.Vector.Vector (target y)
-> target (BuildTensorKind k y)
tfromVectorShare snat@SNat stk v = case stk of
STKScalar{} -> error "tfromVectorShare: vector of scalars"
STKScalar{} -> sfromVector $ V.map sfromScalar v
STKR SNat x | Dict <- lemTensorKindOfSTK x ->
rfromVector v
STKS sh x | Dict <- lemTensorKindOfSTK x ->
Expand All @@ -235,7 +234,7 @@ class ShareTensor (target :: Target) where
-> target (BuildTensorKind k y)
-> [target y]
tunravelToListShare snat@SNat stk u = case stk of
STKScalar{} -> error "tunravelToList: scalar"
STKScalar{} -> map stoScalar $ sunravelToList u
STKR SNat x | Dict <- lemTensorKindOfSTK x ->
runravelToList u
STKS sh x | Dict <- lemTensorKindOfSTK x ->
Expand All @@ -258,7 +257,7 @@ class ShareTensor (target :: Target) where
-> target (BuildTensorKind k z)
-> target z
tsumShare snat@SNat stk u = case stk of
STKScalar{} -> u
STKScalar{} -> stoScalar $ ssum u
STKR SNat x | Dict <- lemTensorKindOfSTK x ->
rsum u
STKS sh x | Dict <- lemTensorKindOfSTK x ->
Expand All @@ -279,7 +278,7 @@ class ShareTensor (target :: Target) where
-> target z
-> target (BuildTensorKind k z)
treplicateShare snat@SNat stk u = case stk of
STKScalar{} -> u
STKScalar{} -> sreplicate $ sfromScalar u
STKR SNat x | Dict <- lemTensorKindOfSTK x -> rreplicate (sNatValue snat) u
STKS sh x | Dict <- lemTensorKindOfSTK x -> withKnownShS sh $ sreplicate u
STKX sh x | Dict <- lemTensorKindOfSTK x -> withKnownShX sh $ xreplicate u
Expand All @@ -300,7 +299,7 @@ class ShareTensor (target :: Target) where
-> target (BuildTensorKind k z) -> IntOf target
-> target z
tindexBuildShare snat@SNat stk u i = case stk of
STKScalar{} -> u
STKScalar{} -> stoScalar $ sindex u (i :.$ ZIS)
STKR SNat x | Dict <- lemTensorKindOfSTK x ->
rindex u (i :.: ZIR)
STKS sh x | Dict <- lemTensorKindOfSTK x ->
Expand All @@ -324,8 +323,8 @@ class ShareTensor (target :: Target) where
class ( Num (IntOf target)
, IntegralF (IntOf target)
, TensorSupports Num Num target
, TensorSupports RealFloatF Floating target
, TensorSupports RealFloatF RealFloatF target
, TensorSupports RealFloatAndFloatElt Floating target
, TensorSupports RealFloatAndFloatElt RealFloatF target
, TensorSupports IntegralF IntegralF target
, TensorSupportsR Num Num target
, TensorSupportsR RealFloatAndFloatElt Floating target
Expand Down Expand Up @@ -1428,7 +1427,7 @@ class ( Num (IntOf target)
let replStk :: STensorKindType z -> (IntOf target -> target z)
-> target (BuildTensorKind k z)
replStk stk g = case stk of
STKScalar{} -> rtoScalar $ rscalar def -- TODO: ???
STKScalar{} -> sbuild1 (sfromScalar . g)
STKR SNat x | Dict <- lemTensorKindOfSTK x ->
rbuild1 (sNatValue snat) g
STKS sh x | Dict <- lemTensorKindOfSTK x ->
Expand Down
4 changes: 2 additions & 2 deletions src/HordeAd/Core/TensorKind.hs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ ftkToStk = \case

buildFTK :: SNat k -> FullTensorKind y -> FullTensorKind (BuildTensorKind k y)
buildFTK snat@SNat = \case
FTKScalar -> FTKScalar
FTKScalar -> FTKS (snat :$$ ZSS) FTKScalar
FTKR sh x -> FTKR (sNatValue snat :$: sh) x
FTKS sh x -> FTKS (snat :$$ sh) x
FTKX sh x -> FTKX (SKnown snat :$% sh) x
Expand All @@ -220,7 +220,7 @@ razeFTK :: forall y k.
-> FullTensorKind (BuildTensorKind k y)
-> FullTensorKind y
razeFTK snat@SNat stk ftk = case (stk, ftk) of
(STKScalar{}, FTKScalar) -> FTKScalar
(STKScalar{}, FTKS (_ :$$ ZSS) FTKScalar) -> FTKScalar
(STKR{}, FTKR (_ :$: sh) x) -> FTKR sh x
(STKR{}, FTKR ZSR _) -> error "razeFTK: impossible built tensor kind"
(STKS{}, FTKS (_ :$$ sh) x) -> FTKS sh x
Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ instance IfDifferentiable Float where
ifDifferentiable ra _ = ra

type family BuildTensorKind k tk where
BuildTensorKind k (TKScalar r) = TKScalar r -- TODO: say why on Earth
BuildTensorKind k (TKScalar r) = TKS '[k] r
BuildTensorKind k (TKR2 n r) = TKR2 (1 + n) r
BuildTensorKind k (TKS2 sh r) = TKS2 (k : sh) r
BuildTensorKind k (TKX2 sh r) = TKX2 (Just k : sh) r
Expand Down

0 comments on commit 3c81602

Please sign in to comment.