Skip to content

Commit

Permalink
Add backend-specific shaped AST constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 1, 2025
1 parent 7f685fa commit 65ce664
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 20 deletions.
25 changes: 25 additions & 0 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import Data.Array.Nested
, MapJust
, Rank
, Replicate
, ShS (..)
, type (++)
)
import Data.Array.Nested qualified as Nested
Expand Down Expand Up @@ -682,6 +683,30 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
=> AstTensor ms s (TKR 2 r)
-> AstTensor ms s (TKR 2 r)
-> AstTensor ms s (TKR 2 r)
AstReplicate0NS :: ShS sh -> STensorKindType x
-> AstTensor ms s (TKS2 '[] x)
-> AstTensor ms s (TKS2 sh x)
AstSum0S :: ShS sh -> STensorKindType x
-> AstTensor ms s (TKS2 sh x)
-> AstTensor ms s (TKS2 '[] x)
AstDot0S :: GoodScalar r
=> ShS sh
-> AstTensor ms s (TKS sh r) -> AstTensor ms s (TKS sh r)
-> AstTensor ms s (TKS '[] r)
AstDot1InS :: GoodScalar r
=> SNat m -> SNat n
-> AstTensor ms s (TKS '[m, n] r) -> AstTensor ms s (TKS '[m, n] r)
-> AstTensor ms s (TKS '[m] r)
AstMatvecmulS :: GoodScalar r
=> SNat m -> SNat n
-> AstTensor ms s (TKS '[m, n] r)
-> AstTensor ms s (TKS '[n] r)
-> AstTensor ms s (TKS '[m] r)
AstMatmul2S :: (GoodScalar r, Numeric r)
=> SNat m -> SNat n -> SNat p
-> AstTensor ms s (TKS '[m, n] r)
-> AstTensor ms s (TKS '[n, p] r)
-> AstTensor ms s (TKS '[m, p] r)

deriving instance Show (AstTensor ms s y)

Expand Down
40 changes: 40 additions & 0 deletions src/HordeAd/Core/AstInline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,26 @@ inlineAst memo v0 = case v0 of
let (memo2, u2) = inlineAst memo u
(memo3, v3) = inlineAst memo2 v
in (memo3, Ast.AstMatmul2R u2 v3)
Ast.AstReplicate0NS sh stk v ->
second (Ast.AstReplicate0NS sh stk) (inlineAst memo v)
Ast.AstSum0S sh stk v ->
second (Ast.AstSum0S sh stk) (inlineAst memo v)
Ast.AstDot0S sh u v ->
let (memo2, u2) = inlineAst memo u
(memo3, v3) = inlineAst memo2 v
in (memo3, Ast.AstDot0S sh u2 v3)
Ast.AstDot1InS m n u v ->
let (memo2, u2) = inlineAst memo u
(memo3, v3) = inlineAst memo2 v
in (memo3, Ast.AstDot1InS m n u2 v3)
Ast.AstMatvecmulS m n u v ->
let (memo2, u2) = inlineAst memo u
(memo3, v3) = inlineAst memo2 v
in (memo3, Ast.AstMatvecmulS m n u2 v3)
Ast.AstMatmul2S m n p u v ->
let (memo2, u2) = inlineAst memo u
(memo3, v3) = inlineAst memo2 v
in (memo3, Ast.AstMatmul2S m n p u2 v3)

inlineAstDynamic
:: AstSpan s
Expand Down Expand Up @@ -806,6 +826,26 @@ unshareAst memo = \case
let (memo2, u2) = unshareAst memo u
(memo3, v3) = unshareAst memo2 v
in (memo3, Ast.AstMatmul2R u2 v3)
Ast.AstReplicate0NS sh stk v ->
second (Ast.AstReplicate0NS sh stk) (unshareAst memo v)
Ast.AstSum0S sh stk v ->
second (Ast.AstSum0S sh stk) (unshareAst memo v)
Ast.AstDot0S sh u v ->
let (memo2, u2) = unshareAst memo u
(memo3, v3) = unshareAst memo2 v
in (memo3, Ast.AstDot0S sh u2 v3)
Ast.AstDot1InS m n u v ->
let (memo2, u2) = unshareAst memo u
(memo3, v3) = unshareAst memo2 v
in (memo3, Ast.AstDot1InS m n u2 v3)
Ast.AstMatvecmulS m n u v ->
let (memo2, u2) = unshareAst memo u
(memo3, v3) = unshareAst memo2 v
in (memo3, Ast.AstMatvecmulS m n u2 v3)
Ast.AstMatmul2S m n p u v ->
let (memo2, u2) = unshareAst memo u
(memo3, v3) = unshareAst memo2 v
in (memo3, Ast.AstMatmul2S m n p u2 v3)

unshareAstDynamic
:: AstSpan s
Expand Down
19 changes: 18 additions & 1 deletion src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import Data.Array.Mixed.Shape (pattern (:.%), pattern ZIX, ssxAppend)
import Data.Array.Nested
(IxR (..), KnownShS (..), KnownShX (..), ListR (..), ListS (..), ShR (..))
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Internal.Shape (shrRank, shsAppend)
import Data.Array.Nested.Internal.Shape (shrRank, shsAppend, shsProduct)

import HordeAd.Core.Ast
import HordeAd.Core.AstEnv
Expand Down Expand Up @@ -728,6 +728,23 @@ interpretAst !env = \case
rmatvecmul (interpretAst env u) (interpretAst env v)
AstMatmul2R u v ->
rmatmul2 (interpretAst env u) (interpretAst env v)
AstReplicate0NS sh stk v | Dict <- lemTensorKindOfSTK stk
, SNat <- shsProduct sh ->
withKnownShS sh $
sreplicate0N (interpretAst env v)
AstSum0S sh stk v | Dict <- lemTensorKindOfSTK stk
, SNat <- shsProduct sh ->
withKnownShS sh $
ssum0 (interpretAst env v)
AstDot0S sh u v | SNat <- shsProduct sh ->
withKnownShS sh $
sdot0 (interpretAst env u) (interpretAst env v)
AstDot1InS SNat n@SNat u v ->
sdot1In n (interpretAst env u) (interpretAst env v)
AstMatvecmulS SNat SNat u v ->
smatvecmul (interpretAst env u) (interpretAst env v)
AstMatmul2S SNat SNat SNat u v ->
smatmul2 (interpretAst env u) (interpretAst env v)

interpretAstDynamic
:: forall target s. (ADReady target, AstSpan s)
Expand Down
22 changes: 22 additions & 0 deletions src/HordeAd/Core/AstPrettyPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,28 @@ printAstAux cfg d = \case
. printAst cfg 11 v
AstMatmul2R u v ->
printPrefixOp printAst cfg d "rmatmul2" [u, v]
AstReplicate0NS _sh stk v | Dict <- lemTensorKindOfSTK stk ->
printPrefixOp printAst cfg d "sreplicate0N" [v]
AstSum0S sh stk v | Dict <- lemTensorKindOfSTK stk ->
withKnownShS sh $
printPrefixOp printAst cfg d "ssum0" [v]
AstDot0S sh u v ->
withKnownShS sh $
printPrefixOp printAst cfg d "sdot0" [u, v]
AstDot1InS SNat SNat u v ->
printPrefixOp printAst cfg d "ssdot1In" [u, v]
AstMatvecmulS SNat SNat u v ->
showParen (d > 10)
$ showString "smatvecmul "
. printAst cfg 11 u
. showString " "
. printAst cfg 11 v
AstMatmul2S SNat SNat SNat u v ->
showParen (d > 10)
$ showString "smatmul2 "
. printAst cfg 11 u
. showString " "
. printAst cfg 11 v
_ -> error "TODO"

-- Differs from standard only in the space after comma.
Expand Down
94 changes: 81 additions & 13 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,12 @@ astNonIndexStep t = case t of
Ast.AstDot1InR{} -> t
Ast.AstMatvecmulR{} -> t
Ast.AstMatmul2R{} -> t
Ast.AstReplicate0NS{} -> t
Ast.AstSum0S{} -> t
Ast.AstDot0S{} -> t
Ast.AstDot1InS{} -> t
Ast.AstMatvecmulS{} -> t
Ast.AstMatmul2S{} -> t

_ -> t -- TODO

Expand Down Expand Up @@ -942,6 +948,14 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 @shm1 i1 rest1)

Ast.AstApply{} -> Ast.AstIndexS v0 ix

-- The below should not appear here unless via wacky tests.
Ast.AstReplicate0NS{} -> Ast.AstIndexS v0 ix
-- impossible: Ast.AstSum0S{} -> Ast.AstIndexS v0 ix
-- impossible: Ast.AstDot0S{} -> Ast.AstIndexS v0 ix
Ast.AstDot1InS{} -> Ast.AstIndexS v0 ix
Ast.AstMatvecmulS{} -> Ast.AstIndexS v0 ix
Ast.AstMatmul2S{} -> Ast.AstIndexS v0 ix

-- TODO: compared to tletIx, it adds many lets, not one, but does not
-- create other (and non-simplified!) big terms and also uses astIsSmall,
-- so it's probably more efficient. Use this instead of tletIx
Expand Down Expand Up @@ -1793,19 +1807,13 @@ astGatherKnobsS knobs v0 (vars0, ix0) =

Ast.AstApply{} -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4)

{- TODO: is this beneficial?
AstGatherS @sh2 @p @sh @r AstIotaS (vars, i :.$ ZIS) ->
gcastWith (unsafeCoerceRefl :: Take (Rank sh2) sh2 :~: sh2)
$ gcastWith (unsafeCoerceRefl :: Drop (Rank sh2) sh2 :~: '[])
$ gcastWith (unsafeCoerceRefl :: Drop p sh :~: '[])
$ gcastWith (unsafeCoerceRefl :: sh2 :~: sh2 ++ Drop p sh)
-- transitivity of type equality doesn't work, by design,
-- so this direct cast is needed instead of more basic laws
$ sbuild @target @r @(Rank sh2)
(interpretLambdaIndexS
interpretAst env
(vars, fromPrimal @s $ AstFromIntegralS $ AstFromScalar i))
-}
-- The below should not appear here unless via wacky tests.
Ast.AstReplicate0NS{} -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4)
-- Ast.AstSum0S{} -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4)
-- Ast.AstDot0S{} -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4)
Ast.AstDot1InS{} -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4)
Ast.AstMatvecmulS{} -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4)
Ast.AstMatmul2S{} -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4)

gatherFromNFS :: forall shm n shp. KnownShS shp
=> AstVarListS shm -> AstIxS AstMethodLet (n ': shp) -> Bool
Expand Down Expand Up @@ -3020,6 +3028,12 @@ astPrimalPart t = case t of
Ast.AstDot1InR{} -> Ast.AstPrimalPart t
Ast.AstMatvecmulR{} -> Ast.AstPrimalPart t
Ast.AstMatmul2R{} -> Ast.AstPrimalPart t
Ast.AstReplicate0NS{} -> Ast.AstPrimalPart t
Ast.AstSum0S{} -> Ast.AstPrimalPart t
Ast.AstDot0S{} -> Ast.AstPrimalPart t
Ast.AstDot1InS{} -> Ast.AstPrimalPart t
Ast.AstMatvecmulS{} -> Ast.AstPrimalPart t
Ast.AstMatmul2S{} -> Ast.AstPrimalPart t

_ -> error "TODO"

Expand Down Expand Up @@ -3142,6 +3156,12 @@ astDualPart t = case t of
Ast.AstDot1InR{} -> Ast.AstDualPart t
Ast.AstMatvecmulR{} -> Ast.AstDualPart t
Ast.AstMatmul2R{} -> Ast.AstDualPart t
Ast.AstReplicate0NS{} -> Ast.AstDualPart t
Ast.AstSum0S{} -> Ast.AstDualPart t
Ast.AstDot0S{} -> Ast.AstDualPart t
Ast.AstDot1InS{} -> Ast.AstDualPart t
Ast.AstMatvecmulS{} -> Ast.AstDualPart t
Ast.AstMatmul2S{} -> Ast.AstDualPart t

_ -> error "TODO"

Expand Down Expand Up @@ -3538,6 +3558,12 @@ expandAst t = case t of
Ast.AstDot1InR{} -> t
Ast.AstMatvecmulR{} -> t
Ast.AstMatmul2R{} -> t
Ast.AstReplicate0NS{} -> t
Ast.AstSum0S{} -> t
Ast.AstDot0S{} -> t
Ast.AstDot1InS{} -> t
Ast.AstMatvecmulS{} -> t
Ast.AstMatmul2S{} -> t

_ -> error "TODO"

Expand Down Expand Up @@ -3750,6 +3776,12 @@ simplifyAst t = case t of
Ast.AstDot1InR{} -> t
Ast.AstMatvecmulR{} -> t
Ast.AstMatmul2R{} -> t
Ast.AstReplicate0NS{} -> t
Ast.AstSum0S{} -> t
Ast.AstDot0S{} -> t
Ast.AstDot1InS{} -> t
Ast.AstMatvecmulS{} -> t
Ast.AstMatmul2S{} -> t

_ -> error "TODO"

Expand Down Expand Up @@ -4181,6 +4213,12 @@ contractAst t = case t of
Ast.AstDot1InR{} -> t
Ast.AstMatvecmulR{} -> t
Ast.AstMatmul2R{} -> t
Ast.AstReplicate0NS{} -> t
Ast.AstSum0S{} -> t
Ast.AstDot0S{} -> t
Ast.AstDot1InS{} -> t
Ast.AstMatvecmulS{} -> t
Ast.AstMatmul2S{} -> t

_ -> error "TODO"

Expand Down Expand Up @@ -4815,6 +4853,36 @@ substitute1Ast i var v1 = case v1 of
in if isJust mu || isJust mv
then Just $ Ast.AstMatmul2R (fromMaybe u mu) (fromMaybe v mv)
else Nothing
Ast.AstReplicate0NS sh stk v | Dict <- lemTensorKindOfSTK stk ->
Ast.AstReplicate0NS sh stk <$> substitute1Ast i var v
Ast.AstSum0S sh stk v | Dict <- lemTensorKindOfSTK stk ->
withKnownShS sh $
Ast.AstSum0S sh stk <$> substitute1Ast i var v
Ast.AstDot0S sh u v ->
withKnownShS sh $
let mu = substitute1Ast i var u
mv = substitute1Ast i var v
in if isJust mu || isJust mv
then Just $ Ast.AstDot0S sh (fromMaybe u mu) (fromMaybe v mv)
else Nothing
Ast.AstDot1InS m@SNat n@SNat u v ->
let mu = substitute1Ast i var u
mv = substitute1Ast i var v
in if isJust mu || isJust mv
then Just $ Ast.AstDot1InS m n(fromMaybe u mu) (fromMaybe v mv)
else Nothing
Ast.AstMatvecmulS m@SNat n@SNat u v ->
let mu = substitute1Ast i var u
mv = substitute1Ast i var v
in if isJust mu || isJust mv
then Just $ Ast.AstMatvecmulS m n (fromMaybe u mu) (fromMaybe v mv)
else Nothing
Ast.AstMatmul2S m@SNat n@SNat p@SNat u v ->
let mu = substitute1Ast i var u
mv = substitute1Ast i var v
in if isJust mu || isJust mv
then Just $ Ast.AstMatmul2S m n p (fromMaybe u mu) (fromMaybe v mv)
else Nothing

_ -> error "TODO"

Expand Down
18 changes: 16 additions & 2 deletions src/HordeAd/Core/AstTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import Data.Array.Mixed.Shape
, shxTakeSSX
)
import Data.Array.Nested
(IShR, KnownShS (..), MapJust, Replicate, ShR (..), ShX (..))
(IShR, KnownShS (..), MapJust, Replicate, ShR (..), ShS (..), ShX (..))
import Data.Array.Nested.Internal.Shape
( shCvtRX
, shCvtSX
Expand Down Expand Up @@ -254,13 +254,21 @@ ftkAst t = case t of
FTKR _ x -> FTKR ZSR x
AstDot0R _ _u _v -> FTKR ZSR FTKScalar
AstDot1InR u _v -> case ftkAst u of
FTKR (n :$: _) FTKScalar -> FTKR (n :$: ZSR) FTKScalar
FTKR (m :$: _) FTKScalar -> FTKR (m :$: ZSR) FTKScalar
AstMatvecmulR u _v -> case ftkAst u of
FTKR (m :$: _) FTKScalar -> FTKR (m :$: ZSR) FTKScalar
AstMatmul2R u v -> case (ftkAst u, ftkAst v) of
(FTKR (m :$: _ :$: ZSR) FTKScalar, FTKR (_ :$: p :$: ZSR) FTKScalar) ->
FTKR (m :$: p :$: ZSR) FTKScalar
_ -> error "ftkAst: impossible pattern needlessly required"
AstReplicate0NS sh _ v -> case ftkAst v of
FTKS _ x -> FTKS sh x
AstSum0S _ _ v -> case ftkAst v of
FTKS _ x -> FTKS ZSS x
AstDot0S _ _u _v -> FTKS ZSS FTKScalar
AstDot1InS m@SNat _ _u _v -> FTKS (m :$$ ZSS) FTKScalar
AstMatvecmulS m@SNat _ _u _v -> FTKS (m :$$ ZSS) FTKScalar
AstMatmul2S m@SNat _ p@SNat _u _v -> FTKS (m :$$ p :$$ ZSS) FTKScalar

_ -> error "TODO"

Expand Down Expand Up @@ -428,6 +436,12 @@ varInAst var = \case
AstDot1InR u v -> varInAst var u || varInAst var v
AstMatvecmulR u v -> varInAst var u || varInAst var v
AstMatmul2R u v -> varInAst var u || varInAst var v
AstReplicate0NS _ _ v -> varInAst var v
AstSum0S _ _ v -> varInAst var v
AstDot0S _ u v -> varInAst var u || varInAst var v
AstDot1InS _ _ u v -> varInAst var u || varInAst var v
AstMatvecmulS _ _ u v -> varInAst var u || varInAst var v
AstMatmul2S _ _ _ u v -> varInAst var u || varInAst var v

varInIndex :: AstVarId -> AstIxR ms n -> Bool
varInIndex var = any (varInAst var)
Expand Down
6 changes: 6 additions & 0 deletions src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,12 @@ build1V snat@SNat (var, v0) =
Ast.AstDot1InR{} -> error "build1V: term not accessible from user API"
Ast.AstMatvecmulR{} -> error "build1V: term not accessible from user API"
Ast.AstMatmul2R{} -> error "build1V: term not accessible from user API"
Ast.AstReplicate0NS{} -> error "build1V: term not accessible from user API"
Ast.AstSum0S{} -> error "build1V: term not accessible from user API"
Ast.AstDot0S{} -> error "build1V: term not accessible from user API"
Ast.AstDot1InS{} -> error "build1V: term not accessible from user API"
Ast.AstMatvecmulS{} -> error "build1V: term not accessible from user API"
Ast.AstMatmul2S{} -> error "build1V: term not accessible from user API"

_ -> error $ "TODO: " ++ show v0

Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/OpsConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ instance BaseTensor RepN where

sscaleByScalar s v =
RepN $ liftVS (V.map (* Nested.sunScalar (unRepN s))) (unRepN v)
sdot1In proxy u v = RepN $ Nested.sdot1Inner proxy (unRepN u) (unRepN v)
sdot1In (SNat @n) u v = RepN $ Nested.sdot1Inner (Proxy @n) (unRepN u) (unRepN v)

sfromPrimal = id
sprimalPart = id
Expand Down
6 changes: 3 additions & 3 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -881,9 +881,9 @@ class ( Num (IntOf target)
=> target (TKS '[] r) -> target (TKS sh r) -> target (TKS sh r)
sscaleByScalar s v = v * sreplicate0N s
sdot1In :: (GoodScalar r, KnownNat n, KnownNat m)
=> Proxy m
-> target (TKS '[n, m] r) -> target (TKS '[n, m] r)
-> target (TKS '[n] r) -- TODO: generalize
=> SNat n
-> target (TKS '[m, n] r) -> target (TKS '[m, n] r)
-> target (TKS '[m] r) -- TODO: generalize
sdot1In _ t u = ssum $ str (t * u)

-- Primal/dual things.
Expand Down

0 comments on commit 65ce664

Please sign in to comment.