From 7b881b1e60d45bc24ef3ff05140fe2a78d751d4c Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Wed, 1 Jan 2025 23:34:20 +0100 Subject: [PATCH] Implement almost all shaped contractAst rewrites --- src/HordeAd/Core/Ast.hs | 4 +- src/HordeAd/Core/AstInterpret.hs | 112 --------------- src/HordeAd/Core/AstSimplify.hs | 231 ++++++++++++++++++++++++++++++- 3 files changed, 229 insertions(+), 118 deletions(-) diff --git a/src/HordeAd/Core/Ast.hs b/src/HordeAd/Core/Ast.hs index c18a1382..726a7b05 100644 --- a/src/HordeAd/Core/Ast.hs +++ b/src/HordeAd/Core/Ast.hs @@ -474,8 +474,8 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType (PermC perm, KnownShS sh, Rank perm <= Rank sh, TensorKind r) => Permutation.Perm perm -> AstTensor ms s (TKS2 sh r) -> AstTensor ms s (TKS2 (Permutation.PermutePrefix perm sh) r) - AstReshapeS :: ( KnownShS sh, Nested.Product sh ~ Nested.Product sh2 - , TensorKind r, KnownShS sh2 ) + AstReshapeS :: ( KnownShS sh, KnownShS sh2 + , Nested.Product sh ~ Nested.Product sh2, TensorKind r) => AstTensor ms s (TKS2 sh r) -> AstTensor ms s (TKS2 sh2 r) -- beware that the order of type arguments is different than in orthotope -- and than the order of value arguments in the ranked version diff --git a/src/HordeAd/Core/AstInterpret.hs b/src/HordeAd/Core/AstInterpret.hs index 06da5659..952627e4 100644 --- a/src/HordeAd/Core/AstInterpret.hs +++ b/src/HordeAd/Core/AstInterpret.hs @@ -429,26 +429,6 @@ interpretAst !env = \case AstFloorS v -> sfloor $ sfromPrimal $ interpretAstPrimalSRuntimeSpecialized env v AstIotaS -> siota -{- TODO: - AstN2R TimesOp [v, AstLet var u (AstReshape sh (AstReplicate @m k s))] - | Just Refl <- sameNat (Proxy @m) (Proxy @0), not (varInAst var v) -> - -- The varInAst check is needed, because although variable - -- capture is impossible, because we don't create nested lets - -- with the same variable, we could create such nested lets - -- if we omitted this check. - interpretAst env - (AstLet var u (AstN2R TimesOp [v, AstReshape sh - (AstReplicate @m k s)])) - AstN2R TimesOp [v, AstReshape sh (AstLet var u (AstReplicate @m k s))] - | Just Refl <- sameNat (Proxy @m) (Proxy @0), not (varInAst var v) -> - interpretAst env - (AstLet var u (AstN2R TimesOp [v, AstReshape sh - (AstReplicate @m k s)])) - AstN2R TimesOp [v, AstLet var u (AstReplicate @m k s)] - | Just Refl <- sameNat (Proxy @m) (Proxy @0), not (varInAst var v) -> - interpretAst env - (AstLet var u (AstN2R TimesOp [v, AstReplicate @m k s])) --} AstN1S opCode u -> let u2 = interpretAst env u in interpretAstN1 opCode u2 @@ -479,90 +459,6 @@ interpretAst !env = \case -- value of the correct rank and shape; this is needed, because -- vectorization can produce out of bound indexing from code where -- the indexing is guarded by conditionals -{- TODO: - AstSum (AstN2R TimesOp [ AstLet vart vt (AstTranspose tperm t) - , AstTranspose uperm u ]) -> - interpretAst env - (AstLet vart vt - (AstSum (AstN2R TimesOp [ AstTranspose tperm t - , AstTranspose uperm u ]))) - AstSum (AstN2R TimesOp [ AstTranspose tperm t - , AstLet varu vu (AstTranspose uperm u) ]) -> - interpretAst env - (AstLet varu vu - (AstSum (AstN2R TimesOp [ AstTranspose tperm t - , AstTranspose uperm u ]))) - AstSum (AstN2R TimesOp [ AstLet vart vt (AstTranspose tperm t) - , AstLet varu vu (AstTranspose uperm u) ]) -> - interpretAst env - (AstLet vart vt (AstLet varu vu - (AstSum (AstN2R TimesOp [ AstTranspose tperm t - , AstTranspose uperm u ])))) - AstSum v@(AstN2R TimesOp [ AstTranspose tperm (AstReplicate _tk t) - , AstTranspose uperm (AstReplicate _uk u) ]) - | Just Refl <- sameNat (Proxy @n) (Proxy @2) -> - let interpretMatmul2 t1 u1 = - let t2 = interpretAst env t1 - u2 = interpretAst env u1 - in rmatmul2 t2 u2 - in case (tperm, uperm) of - ([2, 1, 0], [1, 0]) -> -- tk and uk are fine due to perms matching - interpretMatmul2 t u - ([1, 0], [2, 1, 0]) -> - interpretMatmul2 u t - ([2, 1, 0], [2, 0, 1]) -> - interpretMatmul2 t (astTranspose [1, 0] u) - ([2, 0, 1], [2, 1, 0]) -> - interpretMatmul2 u (astTranspose [1, 0] t) - ([1, 2, 0], [1, 0]) -> - interpretMatmul2 (astTranspose [1, 0] t) u - ([1, 0], [1, 2, 0]) -> - interpretMatmul2 (astTranspose [1, 0] u) t - -- The variants below emerge when the whole term is transposed. - -- All overlap with variants above and the cheaper one is selected. - ([2, 0, 1], [1, 2, 0]) -> - ttr - $ interpretMatmul2 t u - ([1, 2, 0], [2, 0, 1]) -> - ttr - $ interpretMatmul2 u t - _ -> rsum $ interpretAst env v - AstSum (AstN2R TimesOp [t, u]) - | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> - let t1 = interpretAst env t - t2 = interpretAst env u - in rdot0 t1 t2 - -- TODO: do as a term rewrite using an extended set of terms? - AstSum (AstReshape _sh (AstN2R TimesOp [t, u])) - | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> - let t1 = interpretAst env t - t2 = interpretAst env u - in rdot0 t1 t2 - AstSum (AstTranspose [1, 0] (AstN2R TimesOp [t, u])) -- TODO: generalize - | Just Refl <- sameNat (Proxy @n) (Proxy @1) -> - let t1 = interpretAst env t - t2 = interpretAst env u - in rdot1In t1 t2 - AstSum (AstReshape sh (AstTranspose _ t)) - | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> - interpretAst env (AstSum (AstReshape sh t)) - AstSum (AstReshape sh (AstReverse t)) - | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> - interpretAst env (AstSum (AstReshape sh t)) - AstSum (AstReshape _sh (AstSum t)) - | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> - rsum0 $ interpretAst env t - AstSum (AstSum t) - | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> - rsum0 $ interpretAst env t - -- more cases are needed so perhaps we need AstSum0 - AstSum (AstLet var v t) -> interpretAst env (AstLet var v (AstSum t)) - AstSum (AstReshape sh (AstLet var v t)) -> - interpretAst env (AstLet var v (AstSum (AstReshape sh t))) - AstSum (AstReshape _sh t) - | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> - rsum0 $ interpretAst env t --} AstScatterS @_ @shn @shp v (ZS, ix) -> withKnownShS (knownShS @shp `shsAppend` knownShS @shn) $ soneHot (interpretAst env v) (interpretAstPrimal env <$> ix) @@ -573,14 +469,6 @@ interpretAst !env = \case AstFromVectorS l -> let l2 = V.map (interpretAst env) l in sfromVector l2 -{- TODO: - AstReshape sh (AstReplicate @m _ s) - | Just Refl <- sameNat (Proxy @m) (Proxy @0) -> - let t1 = interpretAst env s - in rreplicate0N sh t1 - AstReshape sh (AstLet var v (AstReplicate k t)) -> - interpretAst env (AstLet var v (AstReshape sh (AstReplicate k t))) --} AstAppendS x y -> let t1 = interpretAst env x t2 = interpretAst env y diff --git a/src/HordeAd/Core/AstSimplify.hs b/src/HordeAd/Core/AstSimplify.hs index 2aff3dd4..3a83610c 100644 --- a/src/HordeAd/Core/AstSimplify.hs +++ b/src/HordeAd/Core/AstSimplify.hs @@ -91,6 +91,7 @@ import Data.Array.Nested , KnownShX (..) , ListS (..) , MapJust + , Product , Rank , Replicate , SMayNat (..) @@ -3868,6 +3869,18 @@ contractAst t = case t of TimesOp (Ast.AstTranspose tperm (contractAst t2)) (Ast.AstTranspose uperm (contractAst u))))) + Ast.AstSum snat stk (AstN2S TimesOp + (Ast.AstLet vart vt (Ast.AstTransposeS tperm t2)) + (Ast.AstTransposeS uperm u)) + | Dict <- lemTensorKindOfSTK stk -> + (Ast.AstLet + vart + (contractAst vt) + (contractAst $ Ast.AstSum -- the crucial exposed redex + snat stk (AstN2S + TimesOp + (Ast.AstTransposeS tperm (contractAst t2)) + (Ast.AstTransposeS uperm (contractAst u))))) Ast.AstSum snat stk (AstN2R TimesOp (Ast.AstTranspose tperm t2) (Ast.AstLet varu vu (Ast.AstTranspose uperm u))) @@ -3880,6 +3893,18 @@ contractAst t = case t of TimesOp (Ast.AstTranspose tperm (contractAst t2)) (Ast.AstTranspose uperm (contractAst u))))) + Ast.AstSum snat stk (AstN2S TimesOp + (Ast.AstTransposeS tperm t2) + (Ast.AstLet varu vu (Ast.AstTransposeS uperm u))) + | Dict <- lemTensorKindOfSTK stk -> + (Ast.AstLet + varu + (contractAst vu) + (contractAst $ Ast.AstSum -- the crucial exposed redex + snat stk (AstN2S + TimesOp + (Ast.AstTransposeS tperm (contractAst t2)) + (Ast.AstTransposeS uperm (contractAst u))))) Ast.AstSum snat stk (AstN2R TimesOp (Ast.AstLet vart vt (Ast.AstTranspose tperm t2)) (Ast.AstLet varu vu (Ast.AstTranspose uperm u))) @@ -3895,6 +3920,21 @@ contractAst t = case t of TimesOp (Ast.AstTranspose tperm (contractAst t2)) (Ast.AstTranspose uperm (contractAst u)))))) + Ast.AstSum snat stk (AstN2S TimesOp + (Ast.AstLet vart vt (Ast.AstTransposeS tperm t2)) + (Ast.AstLet varu vu (Ast.AstTransposeS uperm u))) + | Dict <- lemTensorKindOfSTK stk -> + (Ast.AstLet + vart + (contractAst vt) + (Ast.AstLet + varu + (contractAst vu) + (contractAst $ Ast.AstSum -- the crucial exposed redex + snat stk (AstN2S + TimesOp + (Ast.AstTransposeS tperm (contractAst t2)) + (Ast.AstTransposeS uperm (contractAst u)))))) Ast.AstSum snat stk@(STKR (SNat @n) (STKScalar rRep)) v@(AstN2R TimesOp (Ast.AstTranspose tperm (Ast.AstReplicate _tk stkt t2)) (Ast.AstTranspose uperm (Ast.AstReplicate _uk stku u2))) @@ -3951,31 +3991,136 @@ contractAst t = case t of -- Ast.AstTranspose [1, 0] -- $ interpretMatmul2 (astTranspose [1, 0] u2) (astTranspose [1, 0] t2) _ -> astSum snat stk (contractAst v) +{-TODO: Ast.AstSum snat stk@(STKS (_ :$$ _ :$$ ZSS) (STKScalar @r rRep)) + v@(AstN2S TimesOp (Ast.AstTransposeS tperm (Ast.AstReplicate _tk stkt t2)) + (Ast.AstTransposeS uperm (Ast.AstReplicate _uk stku u2))) + -> case (stkt, stku) of + (STKS{}, STKS{}) -> + let perm10 = Permutation.makePerm @'[1, 0] + attemptMatmul2 :: forall m' n' p'. + AstTensor AstMethodLet s (TKS '[m', n'] r) + -> AstTensor AstMethodLet s (TKS '[n', p'] r) + -> AstTensor AstMethodLet s (TKS '[m', p'] r) + attemptMatmul2 t3 u3 = + let t4 = contractAst t3 + u4 = contractAst u3 + in case testEquality rRep (typeRep @Double) of + Just Refl -> Ast.AstMatmul2S + (SNat @m') (SNat @n') (SNat @p') t4 u4 + _ -> case testEquality rRep (typeRep @Float) of + Just Refl -> Ast.AstMatmul2S + (SNat @m') (SNat @n') (SNat @p') t4 u4 + _ -> case testEquality rRep (typeRep @Int64) of + Just Refl -> Ast.AstMatmul2S + (SNat @m') (SNat @n') (SNat @p') t4 u4 + _ -> case testEquality rRep (typeRep @CInt) of + Just Refl -> Ast.AstMatmul2S + (SNat @m') (SNat @n') (SNat @p') t4 u4 + _ -> astSum snat stk (contractAst v) + in if | Just Refl <- geq tperm (Permutation.makePerm @'[2, 1, 0]) + , Just Refl <- geq uperm (Permutation.makePerm @'[1, 0]) -> + -- tk and uk are fine due to perms matching + attemptMatmul2 t2 u2 + + ([1, 0], [2, 1, 0]) -> + attemptMatmul2 u2 t2 + ([2, 1, 0], [2, 0, 1]) -> + attemptMatmul2 t2 (astTransposeS perm10 u2) + ([2, 0, 1], [2, 1, 0]) -> + attemptMatmul2 u2 (astTransposeS perm10 t2) + ([1, 2, 0], [1, 0]) -> + attemptMatmul2 (astTransposeS perm10 t2) u2 + ([1, 0], [1, 2, 0]) -> + attemptMatmul2 (astTransposeS perm10 u2) t2 +-- ([1, 2, 0], [2, 0, 1]) -> +-- attemptMatmul2 (astTransposeS perm10 t2) +-- (astTransposeS perm10 u2) +-- ([2, 0, 1], [1, 2, 0]) -> +-- attemptMatmul2 (astTransposeS perm10 u2) +-- (astTransposeS perm10 t2) + -- The variants below emerge when the whole term is transposed. + -- All overlap with variants above and the cheaper one is selected. + ([2, 0, 1], [1, 2, 0]) -> + Ast.AstTransposeS perm10 $ attemptMatmul2 t2 u2 + ([1, 2, 0], [2, 0, 1]) -> + Ast.AstTransposeS perm10 $ attemptMatmul2 u2 t2 +-- ([2, 0, 1], [2, 1, 0]) -> +-- Ast.AstTranspose [1, 0] +-- $ attemptMatmul2 t2 (astTransposeS perm10 u2) +-- ([2, 1, 0], [2, 0, 1]) -> +-- Ast.AstTranspose [1, 0] +-- $ attemptMatmul2 u2 (astTransposeS perm10 t2) +-- ([1, 2, 0], [1, 0]) -> +-- Ast.AstTranspose [1, 0] +-- $ attemptMatmul2 (astTransposeS perm10 u2) t2 +-- ([1, 0], [2, 1, 0]) -> +-- Ast.AstTransposeS perm10 +-- $ attemptMatmul2 (astTransposeS perm10 t2) +-- (astTransposeS perm10 u2) +-- ([2, 1, 0], [1, 0]) -> +-- Ast.AstTransposeS perm10 +-- $ attemptMatmul2 (astTranspose 0S perm10 u2) +-- (astTransposeS perm10 t2) + _ -> astSum snat stk (contractAst v) -} Ast.AstSum _ (STKR (SNat @n) _) (AstN2R TimesOp t2 u) | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> Ast.AstDot0R (SNat @1) (contractAst t2) (contractAst u) + Ast.AstSum k@SNat (STKS ZSS _) (AstN2S TimesOp t2 u) -> + Ast.AstDot0S (k :$$ ZSS) (contractAst t2) (contractAst u) Ast.AstSum _ (STKR (SNat @n) _) (Ast.AstReshape @p _sh (AstN2R TimesOp t2 u)) | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> Ast.AstDot0R (SNat @p) (contractAst t2) (contractAst u) + Ast.AstSum _ (STKS ZSS _) (Ast.AstReshapeS @sh2 (AstN2S TimesOp t2 u)) -> + Ast.AstDot0S (knownShS @sh2) (contractAst t2) (contractAst u) Ast.AstSum _ (STKR (SNat @n) _) (Ast.AstTranspose [1, 0] (AstN2R TimesOp t2 u)) -- TODO: generalize | Just Refl <- sameNat (Proxy @n) (Proxy @1) -> Ast.AstDot1InR (contractAst t2) (contractAst u) + Ast.AstSum + n@(SNat @n) + (STKS (m@(SNat @m) :$$ ZSS) _) + (Ast.AstTransposeS @perm @sh + (SNat @n1 `Permutation.PCons` SNat @n0 + `Permutation.PCons` Permutation.PNil) + (AstN2S TimesOp t2 u)) + -- TODO: generalize +-- TODO: | Just Refl <- testEquality perm (Permutation.makePerm @'[1, 0]) -> + | Just Refl <- sameNat (Proxy @n0) (Proxy @0) + , Just Refl <- sameNat (Proxy @n1) (Proxy @1) -> + -- TODO: Why is this needed? Would a more general lemma suffice? + gcastWith (unsafeCoerceRefl + :: Permutation.PermutePrefix perm [n, m] :~: sh) $ + Ast.AstDot1InS m n (contractAst t2) (contractAst u) Ast.AstSum snat stk@(STKR (SNat @n) _) (Ast.AstReshape sh (Ast.AstTranspose _ t2)) | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> contractAst (Ast.AstSum snat stk (Ast.AstReshape sh t2)) + Ast.AstSum + snat stk@(STKS ZSS _) (Ast.AstReshapeS + @sh3 @sh (Ast.AstTransposeS @_ @sh2 _ t2)) -> + gcastWith (unsafeCoerceRefl :: Product sh2 :~: Product sh3) $ + contractAst (Ast.AstSum snat stk (Ast.AstReshapeS @sh2 @sh t2)) Ast.AstSum snat stk@(STKR (SNat @n) _) (Ast.AstReshape sh (Ast.AstReverse t2)) | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> contractAst (Ast.AstSum snat stk (Ast.AstReshape sh t2)) + Ast.AstSum snat stk@(STKS ZSS _) (Ast.AstReshapeS + @_ @sh (Ast.AstReverseS t2)) -> + contractAst (Ast.AstSum snat stk (Ast.AstReshapeS @_ @sh t2)) Ast.AstSum _ (STKR (SNat @n) x) (Ast.AstReshape @p _sh (Ast.AstSum _ _ t2)) | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> Ast.AstSum0R (SNat @(1 + p)) x (contractAst t2) + Ast.AstSum _ (STKS ZSS x) (Ast.AstReshapeS @sh2 (Ast.AstSum k2@SNat _ t2)) -> + Ast.AstSum0S (k2 :$$ knownShS @sh2) x (contractAst t2) Ast.AstSum _ (STKR (SNat @n) x) (Ast.AstSum _ _ t2) | Just Refl <- sameNat (Proxy @n) (Proxy @0) , Dict <- lemTensorKindOfSTK x -> - Ast.AstSum0R (SNat @(2 + n)) x (contractAst t2) -- TODO: more cases are needed + Ast.AstSum0R (SNat @(2 + n)) x (contractAst t2) + -- TODO: more cases are needed + Ast.AstSum k@SNat (STKS ZSS x) (Ast.AstSum k2@SNat _ t2) + | Dict <- lemTensorKindOfSTK x -> + Ast.AstSum0S (k2 :$$ k :$$ ZSS) x (contractAst t2) + -- TODO: more cases are needed Ast.AstSum snat stk (Ast.AstLet var v t2) | Dict <- lemTensorKindOfSTK stk -> contractAst (Ast.AstLet var v (Ast.AstSum snat stk t2)) Ast.AstSum snat stk (Ast.AstReshape sh (Ast.AstLet var v t2)) @@ -3984,9 +4129,17 @@ contractAst t = case t of var v (Ast.AstSum snat stk (Ast.AstReshape sh t2))) + Ast.AstSum snat stk (Ast.AstReshapeS @sh (Ast.AstLet var v t2)) + | Dict <- lemTensorKindOfSTK stk -> + contractAst (Ast.AstLet + var + v + (Ast.AstSum snat stk (Ast.AstReshapeS @sh t2))) Ast.AstSum _ (STKR (SNat @n) x) (Ast.AstReshape @p _sh t2) | Just Refl <- sameNat (Proxy @n) (Proxy @0) -> Ast.AstSum0R (SNat @p) x (contractAst t2) + Ast.AstSum _ (STKS ZSS x) (Ast.AstReshapeS @sh2 t2) -> + Ast.AstSum0S (knownShS @sh2) x (contractAst t2) Ast.AstSum snat stk v | Dict <- lemTensorKindOfBuild snat stk -> astSum snat stk (contractAst v) Ast.AstReplicate snat stk v | Dict <- lemTensorKindOfSTK stk -> @@ -4001,18 +4154,43 @@ contractAst t = case t of , Just Refl <- sameNat (Proxy @n) (Proxy @0) , var == var2, sNatValue snat == lengthAst u -> Ast.AstMatvecmulR (contractAst u) (contractAst t2) + Ast.AstBuild1 @y2 + snat (var, Ast.AstSum + n _ + (AstN2S + TimesOp + t2 + (Ast.AstIndexS + u (((:.$) @m (AstIntVar var2) ZIS))))) + | STKS ZSS _ <- stensorKind @y2 + , Just Refl <- geq snat (SNat @m) + , var == var2 -> + Ast.AstMatvecmulS snat n (contractAst u) (contractAst t2) Ast.AstBuild1 @y2 snat (var, Ast.AstSum _ _ (Ast.AstReshape @p _sh (AstN2R TimesOp t2 (Ast.AstIndex - u (AstIntVar - var2 :.: ZIR))))) + u (AstIntVar var2 + :.: ZIR))))) | STKR (SNat @n) _ <- stensorKind @y2 , Just Refl <- sameNat (Proxy @n) (Proxy @0) , Just Refl <- sameNat (Proxy @p) (Proxy @1) , var == var2, sNatValue snat == lengthAst u -> Ast.AstMatvecmulR (contractAst u) (contractAst t2) + Ast.AstBuild1 + @y2 snat (var, Ast.AstSum _ _ + (Ast.AstReshapeS + @sh (AstN2S + TimesOp + t2 + (Ast.AstIndexS + u (((:.$) @m (AstIntVar var2) ZIS)))))) + | STKS ZSS _ <- stensorKind @y2 + , n :$$ ZSS <- knownShS @sh + , Just Refl <- geq snat (SNat @m) + , var == var2 -> + Ast.AstMatvecmulS snat n (contractAst u) (contractAst t2) Ast.AstBuild1 k (var, v) -> Ast.AstBuild1 k (var, contractAst v) Ast.AstLet var u v -> astLet var (contractAst u) (contractAst v) AstConcrete{} -> t @@ -4057,7 +4235,23 @@ contractAst t = case t of var (contractAst u) (AstN2R - TimesOp v (Ast.AstReshape sh + TimesOp v (Ast.AstReshape sh + (Ast.AstReplicate + (SNat @m) stk (contractAst s)))) + AstN2S TimesOp v (Ast.AstLet var u + (Ast.AstReshapeS @_ @sh + (Ast.AstReplicate (SNat @m) stk s))) + | Just Refl <- sameNat (Proxy @m) (Proxy @0), not (varNameInAst var v) + , Dict <- lemTensorKindOfSTK stk -> + -- The varNameInAst check is needed, because although variable + -- capture is impossible, because we don't create nested lets + -- with the same variable, we could create such nested lets + -- if we omitted this check. + Ast.AstLet + var + (contractAst u) + (AstN2S + TimesOp v (Ast.AstReshapeS @_ @sh (Ast.AstReplicate (SNat @m) stk (contractAst s)))) AstN2R TimesOp v (Ast.AstReshape sh @@ -4072,6 +4266,18 @@ contractAst t = case t of TimesOp v (astReshape sh (Ast.AstReplicate (SNat @m) stk (contractAst s)))) + AstN2S TimesOp v (Ast.AstReshapeS @_ @sh + (Ast.AstLet + var u (Ast.AstReplicate (SNat @m) stk s))) + | Just Refl <- sameNat (Proxy @m) (Proxy @0), not (varNameInAst var v) + , Dict <- lemTensorKindOfSTK stk -> + Ast.AstLet + var + (contractAst u) + (AstN2S + TimesOp v (astReshapeS @_ @sh + (Ast.AstReplicate + (SNat @m) stk (contractAst s)))) AstN2R TimesOp v (Ast.AstLet var u (Ast.AstReplicate (SNat @m) stk s)) | Just Refl <- sameNat (Proxy @m) (Proxy @0), not (varNameInAst var v) , Dict <- lemTensorKindOfSTK stk -> @@ -4081,6 +4287,15 @@ contractAst t = case t of (AstN2R TimesOp v (Ast.AstReplicate (SNat @m) stk (contractAst s))) + AstN2S TimesOp v (Ast.AstLet var u (Ast.AstReplicate (SNat @m) stk s)) + | Just Refl <- sameNat (Proxy @m) (Proxy @0), not (varNameInAst var v) + , Dict <- lemTensorKindOfSTK stk -> + Ast.AstLet + var + (contractAst u) + (AstN2S + TimesOp v (Ast.AstReplicate + (SNat @m) stk (contractAst s))) AstN2R opCode u v -> AstN2R opCode (contractAst u) (contractAst v) Ast.AstR1R opCode u -> Ast.AstR1R opCode (contractAst u) Ast.AstR2R opCode u v -> Ast.AstR2R opCode (contractAst u) (contractAst v) @@ -4098,12 +4313,20 @@ contractAst t = case t of Ast.AstReshape sh (Ast.AstReplicate _ (STKR @m _ x) s) | Just Refl <- sameNat (Proxy @m) (Proxy @0) -> Ast.AstReplicate0NR sh x (contractAst s) + Ast.AstReshapeS @_ @sh (Ast.AstReplicate _ (STKS ZSS x) s) -> + Ast.AstReplicate0NS (knownShS @sh) x (contractAst s) Ast.AstReshape sh (Ast.AstLet var v (Ast.AstReplicate snat stk t2)) | Dict <- lemTensorKindOfSTK stk -> Ast.AstLet var (contractAst v) (astReshape sh (Ast.AstReplicate snat stk (contractAst t2))) + Ast.AstReshapeS @_ @sh (Ast.AstLet var v (Ast.AstReplicate snat stk t2)) + | Dict <- lemTensorKindOfSTK stk -> + Ast.AstLet + var + (contractAst v) + (astReshapeS @_ @sh (Ast.AstReplicate snat stk (contractAst t2))) Ast.AstReshape sh v -> astReshape sh (contractAst v) Ast.AstGather sh v (vars, ix) -> astGatherR sh (contractAst v) (vars, contractAstIxR ix)