Skip to content

Commit

Permalink
Implement almost all shaped contractAst rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 1, 2025
1 parent 65ce664 commit 7b881b1
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 118 deletions.
4 changes: 2 additions & 2 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,8 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
(PermC perm, KnownShS sh, Rank perm <= Rank sh, TensorKind r)
=> Permutation.Perm perm -> AstTensor ms s (TKS2 sh r)
-> AstTensor ms s (TKS2 (Permutation.PermutePrefix perm sh) r)
AstReshapeS :: ( KnownShS sh, Nested.Product sh ~ Nested.Product sh2
, TensorKind r, KnownShS sh2 )
AstReshapeS :: ( KnownShS sh, KnownShS sh2
, Nested.Product sh ~ Nested.Product sh2, TensorKind r)
=> AstTensor ms s (TKS2 sh r) -> AstTensor ms s (TKS2 sh2 r)
-- beware that the order of type arguments is different than in orthotope
-- and than the order of value arguments in the ranked version
Expand Down
112 changes: 0 additions & 112 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -429,26 +429,6 @@ interpretAst !env = \case
AstFloorS v ->
sfloor $ sfromPrimal $ interpretAstPrimalSRuntimeSpecialized env v
AstIotaS -> siota
{- TODO:
AstN2R TimesOp [v, AstLet var u (AstReshape sh (AstReplicate @m k s))]
| Just Refl <- sameNat (Proxy @m) (Proxy @0), not (varInAst var v) ->
-- The varInAst check is needed, because although variable
-- capture is impossible, because we don't create nested lets
-- with the same variable, we could create such nested lets
-- if we omitted this check.
interpretAst env
(AstLet var u (AstN2R TimesOp [v, AstReshape sh
(AstReplicate @m k s)]))
AstN2R TimesOp [v, AstReshape sh (AstLet var u (AstReplicate @m k s))]
| Just Refl <- sameNat (Proxy @m) (Proxy @0), not (varInAst var v) ->
interpretAst env
(AstLet var u (AstN2R TimesOp [v, AstReshape sh
(AstReplicate @m k s)]))
AstN2R TimesOp [v, AstLet var u (AstReplicate @m k s)]
| Just Refl <- sameNat (Proxy @m) (Proxy @0), not (varInAst var v) ->
interpretAst env
(AstLet var u (AstN2R TimesOp [v, AstReplicate @m k s]))
-}
AstN1S opCode u ->
let u2 = interpretAst env u
in interpretAstN1 opCode u2
Expand Down Expand Up @@ -479,90 +459,6 @@ interpretAst !env = \case
-- value of the correct rank and shape; this is needed, because
-- vectorization can produce out of bound indexing from code where
-- the indexing is guarded by conditionals
{- TODO:
AstSum (AstN2R TimesOp [ AstLet vart vt (AstTranspose tperm t)
, AstTranspose uperm u ]) ->
interpretAst env
(AstLet vart vt
(AstSum (AstN2R TimesOp [ AstTranspose tperm t
, AstTranspose uperm u ])))
AstSum (AstN2R TimesOp [ AstTranspose tperm t
, AstLet varu vu (AstTranspose uperm u) ]) ->
interpretAst env
(AstLet varu vu
(AstSum (AstN2R TimesOp [ AstTranspose tperm t
, AstTranspose uperm u ])))
AstSum (AstN2R TimesOp [ AstLet vart vt (AstTranspose tperm t)
, AstLet varu vu (AstTranspose uperm u) ]) ->
interpretAst env
(AstLet vart vt (AstLet varu vu
(AstSum (AstN2R TimesOp [ AstTranspose tperm t
, AstTranspose uperm u ]))))
AstSum v@(AstN2R TimesOp [ AstTranspose tperm (AstReplicate _tk t)
, AstTranspose uperm (AstReplicate _uk u) ])
| Just Refl <- sameNat (Proxy @n) (Proxy @2) ->
let interpretMatmul2 t1 u1 =
let t2 = interpretAst env t1
u2 = interpretAst env u1
in rmatmul2 t2 u2
in case (tperm, uperm) of
([2, 1, 0], [1, 0]) -> -- tk and uk are fine due to perms matching
interpretMatmul2 t u
([1, 0], [2, 1, 0]) ->
interpretMatmul2 u t
([2, 1, 0], [2, 0, 1]) ->
interpretMatmul2 t (astTranspose [1, 0] u)
([2, 0, 1], [2, 1, 0]) ->
interpretMatmul2 u (astTranspose [1, 0] t)
([1, 2, 0], [1, 0]) ->
interpretMatmul2 (astTranspose [1, 0] t) u
([1, 0], [1, 2, 0]) ->
interpretMatmul2 (astTranspose [1, 0] u) t
-- The variants below emerge when the whole term is transposed.
-- All overlap with variants above and the cheaper one is selected.
([2, 0, 1], [1, 2, 0]) ->
ttr
$ interpretMatmul2 t u
([1, 2, 0], [2, 0, 1]) ->
ttr
$ interpretMatmul2 u t
_ -> rsum $ interpretAst env v
AstSum (AstN2R TimesOp [t, u])
| Just Refl <- sameNat (Proxy @n) (Proxy @0) ->
let t1 = interpretAst env t
t2 = interpretAst env u
in rdot0 t1 t2
-- TODO: do as a term rewrite using an extended set of terms?
AstSum (AstReshape _sh (AstN2R TimesOp [t, u]))
| Just Refl <- sameNat (Proxy @n) (Proxy @0) ->
let t1 = interpretAst env t
t2 = interpretAst env u
in rdot0 t1 t2
AstSum (AstTranspose [1, 0] (AstN2R TimesOp [t, u])) -- TODO: generalize
| Just Refl <- sameNat (Proxy @n) (Proxy @1) ->
let t1 = interpretAst env t
t2 = interpretAst env u
in rdot1In t1 t2
AstSum (AstReshape sh (AstTranspose _ t))
| Just Refl <- sameNat (Proxy @n) (Proxy @0) ->
interpretAst env (AstSum (AstReshape sh t))
AstSum (AstReshape sh (AstReverse t))
| Just Refl <- sameNat (Proxy @n) (Proxy @0) ->
interpretAst env (AstSum (AstReshape sh t))
AstSum (AstReshape _sh (AstSum t))
| Just Refl <- sameNat (Proxy @n) (Proxy @0) ->
rsum0 $ interpretAst env t
AstSum (AstSum t)
| Just Refl <- sameNat (Proxy @n) (Proxy @0) ->
rsum0 $ interpretAst env t
-- more cases are needed so perhaps we need AstSum0
AstSum (AstLet var v t) -> interpretAst env (AstLet var v (AstSum t))
AstSum (AstReshape sh (AstLet var v t)) ->
interpretAst env (AstLet var v (AstSum (AstReshape sh t)))
AstSum (AstReshape _sh t)
| Just Refl <- sameNat (Proxy @n) (Proxy @0) ->
rsum0 $ interpretAst env t
-}
AstScatterS @_ @shn @shp v (ZS, ix) ->
withKnownShS (knownShS @shp `shsAppend` knownShS @shn) $
soneHot (interpretAst env v) (interpretAstPrimal env <$> ix)
Expand All @@ -573,14 +469,6 @@ interpretAst !env = \case
AstFromVectorS l ->
let l2 = V.map (interpretAst env) l
in sfromVector l2
{- TODO:
AstReshape sh (AstReplicate @m _ s)
| Just Refl <- sameNat (Proxy @m) (Proxy @0) ->
let t1 = interpretAst env s
in rreplicate0N sh t1
AstReshape sh (AstLet var v (AstReplicate k t)) ->
interpretAst env (AstLet var v (AstReshape sh (AstReplicate k t)))
-}
AstAppendS x y ->
let t1 = interpretAst env x
t2 = interpretAst env y
Expand Down
Loading

0 comments on commit 7b881b1

Please sign in to comment.