Skip to content

Commit

Permalink
Mark AstSFromR as normal form to prevent useless astReshapeAsGatherS …
Browse files Browse the repository at this point in the history
…expansion
  • Loading branch information
Mikolaj committed Jan 18, 2025
1 parent 783357a commit 6e5852b
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 15 deletions.
4 changes: 4 additions & 0 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3998,6 +3998,8 @@ expandAst t = case t of
Ast.AstScatterS @_ @_ @shp _ _
| gcompare (Permutation.permRank perm)
(shsRank $ knownShS @shp) == GGT -> t -- nf
Ast.AstSFromR{} -> t -- normal form
Ast.AstSFromX{} -> t -- normal form
_ -> -- not nf, let's express all as a gather
astTransposeAsGatherS (defaultKnobs {knobExpand = True})
perm -- TODO: (normalizePermutation perm)
Expand All @@ -4017,6 +4019,8 @@ expandAst t = case t of
Ast.AstR1S _ w | isVar w -> t -- normal form
Ast.AstR2S _ x y | isVar x && isVar y -> t -- normal form
Ast.AstScatterS{} -> t -- normal form
Ast.AstSFromR{} -> t -- normal form
Ast.AstSFromX{} -> t -- normal form
_ -> -- not nf, let's express all as a gather
astReshapeAsGatherS (defaultKnobs {knobExpand = True})
(expandAst v)
Expand Down
8 changes: 4 additions & 4 deletions test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ testReluSimplerPP4 = do
resetVarCounter
let (artifactRev, _deltas) = revArtifactAdapt True reluT2 (rreplicate0N [3, 4] (rscalar 128), rscalar 42)
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\m11 x1 -> tfromS (let m12 = sgather (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0])) (\\[i8, i9] -> [ifF (rfromS (sfromR (tproject1 m1) !$ [i8, i9] * sfromR (tproject2 m1)) <=. rfromS (sscalar 0.0)) 0 1]) * sfromR m11 in tpair (sreplicate (sreplicate (sfromR (tproject2 m1))) * m12, sdot0 (sgather (sfromR (tproject1 m1)) (\\[i13] -> [remF (quotF i13 4) 3, remF i13 4])) (sreshape m12)))"
@?= "\\m11 x1 -> tfromS (let m12 = sgather (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0])) (\\[i8, i9] -> [ifF (rfromS (sfromR (tproject1 m1) !$ [i8, i9] * sfromR (tproject2 m1)) <=. rfromS (sscalar 0.0)) 0 1]) * sfromR m11 in tpair (sreplicate (sreplicate (sfromR (tproject2 m1))) * m12, sdot0 (sreshape (sfromR (tproject1 m1))) (sreshape m12)))"
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\x1 -> rfromS (let m7 = sfromR (tproject1 m1) * sreplicate (sreplicate (sfromR (tproject2 m1))) in sgather (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0])) (\\[i8, i9] -> [ifF (rfromS (m7 !$ [i8, i9]) <=. rfromS (sscalar 0.0)) 0 1]) * m7)"

Expand Down Expand Up @@ -1119,7 +1119,7 @@ testDot2PP :: Assertion
testDot2PP = do
resetVarCounter >> resetIdCounter
let renames = IM.empty
(artifactRev, deltas) =
(artifactRev, _deltas) =
revArtifactAdapt True (uncurry (rdot0 @(AstTensor AstMethodLet FullSpan) @Double @2))
( ringestData [2,3] [1 .. 6]
, ringestData [2,3] [7 .. 12] )
Expand All @@ -1130,7 +1130,7 @@ testDot2PP = do
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\x4 x1 -> tfromS (tpair (sfromR (tproject2 m1) * sreplicate (sreplicate (sfromR x4)), sfromR (tproject1 m1) * sreplicate (sreplicate (sfromR x4))))"
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\x1 -> rfromS (sdot0 (sgather (sfromR (tproject1 m1)) (\\[i15] -> [remF (quotF i15 3) 2, remF i15 3])) (sgather (sfromR (tproject2 m1)) (\\[i16] -> [remF (quotF i16 3) 2, remF i16 3])))"
@?= "\\x1 -> rfromS (sdot0 (sreshape (sfromR (tproject1 m1))) (sreshape (sfromR (tproject2 m1))))"

testMatvecmulPP :: Assertion
testMatvecmulPP = do
Expand All @@ -1148,7 +1148,7 @@ testMatvecmulPP = do
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\v3 x1 -> tfromS (tpair (sreplicate (sfromR (tproject2 m1)) * stranspose (sreplicate (sfromR v3)), ssum (sfromR (tproject1 m1) * stranspose (sreplicate (sfromR v3)))))"
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\x1 -> rfromS (ssum (stranspose (sreplicate (sfromR (tproject2 m1))) * sgather (sfromR (tproject1 m1)) (\\[i6, i7] -> [i7, i6])))"
@?= "\\x1 -> rfromS (ssum (stranspose (sreplicate (sfromR (tproject2 m1))) * stranspose (sfromR (tproject1 m1))))"

testMatmul2PP :: Assertion
testMatmul2PP = do
Expand Down
2 changes: 1 addition & 1 deletion test/simplified/TestConvSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ testTomsSlicePP = do
resetVarCounter >> resetIdCounter
let renames = IM.empty
f = codeTomsSlice
(artifactRev, delta) = revArtifactAdapt True f (rreshape [32, 4] t128)
(artifactRev, _delta) = revArtifactAdapt True f (rreshape [32, 4] t128)
printArtifactPretty renames artifactRev
@?= "\\x18 x1 -> let v14 = sreshape (sgather (sfromR m1) (\\[i10, i11] -> [i10, i11])) ; v15 = sreshape (sgather (sfromR m1) (\\[i12, i13] -> [i12, 1 + i13])) ; v16 = sfromIntegral siota ; v19 = sreplicate (ssum (v16 * ssum (stranspose (sreshape (sreplicate (sfromR x18)))))) in rfromS (sscatter (sreshape (v15 * v19)) (\\[i22, i23] -> [i22, i23]) + sscatter (sreshape (v14 * v19)) (\\[i20, i21] -> [i20, 1 + i21]))"
printArtifactPretty renames (simplifyArtifact artifactRev)
Expand Down
8 changes: 4 additions & 4 deletions test/simplified/TestGatherSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,12 @@ testGatherSimpPP22 = do
resetVarCounter
let !t1 = gatherReshape22 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (FTKR [6, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 122
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 604
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 122
resetVarCounter
let !t2 = rreshape @(AstTensor AstMethodLet PrimalSpan) @_ @2 @2 [2, 6]
$ AstVar (FTKR [6, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 122
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 604
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 122

testGatherSimpPP23 :: Assertion
testGatherSimpPP23 = do
Expand All @@ -335,14 +335,14 @@ testGatherSimpPP23 = do
(t * rreplicate0N [6, 2] (rfromIndex0 i))))
$ AstVar (FTKR [6, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 349
length (show (simplifyInlineContract @(TKR 3 Float) t1)) @?= 1119
length (show (simplifyInlineContract @(TKR 3 Float) t1)) @?= 637
resetVarCounter
let !t2 = (\t -> rbuild1 4 (\i ->
rreshape @(AstTensor AstMethodLet PrimalSpan) @_ @2 @2 [2, 6]
(t * rreplicate0N [6, 2] (rfromIndex0 i))))
$ AstVar (FTKR [6, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 349
length (show (simplifyInlineContract @(TKR 3 Float) t2)) @?= 1119
length (show (simplifyInlineContract @(TKR 3 Float) t2)) @?= 637

-- Depending on if and how transpose it desugared, this may or may not result
-- in dozens of nested gathers that should vanish after simplification.
Expand Down
8 changes: 4 additions & 4 deletions test/simplified/TestMnistFCNNR.hs
Original file line number Diff line number Diff line change
Expand Up @@ -773,9 +773,9 @@ testVT2OPP = do
printArtifactPrimalPretty renames artifactRev
@?= "\\x1 -> let m5 = stranspose (sreplicate (scast (ssum (stranspose (sreplicate (sreplicate (sscalar 7.0))) * stranspose (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; m6 = stranspose (sreplicate (scast (ssum (m5 * stranspose (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))))) in rfromS (ssum (m6 * stranspose (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1)))"
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\v7 x1 -> tfromS (let m5 = stranspose (sreplicate (scast (ssum (stranspose (sreplicate (sreplicate (sscalar 7.0))) * stranspose (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; v8 = ssum (sfromR (tproject1 (tproject2 m1)) * stranspose (sreplicate (sfromR v7))) ; m9 = sreplicate (scast v8) ; v10 = scast (ssum (sfromR (tproject1 (tproject2 (tproject1 m1))) * stranspose m9)) in tpair (tpair (tpair (sreplicate (sreplicate (sscalar 7.0)) * stranspose (sreplicate v10), v10), tpair (stranspose (m5 * m9), v8)), tpair (sreplicate (scast (ssum (m5 * sgather (sfromR (tproject1 (tproject2 (tproject1 m1)))) (\\[i11, i12] -> [i12, i11]))) + sfromR (tproject2 (tproject2 (tproject1 m1)))) * stranspose (sreplicate (sfromR v7)), v7)))"
@?= "\\v7 x1 -> tfromS (let m5 = stranspose (sreplicate (scast (ssum (stranspose (sreplicate (sreplicate (sscalar 7.0))) * stranspose (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; v8 = ssum (sfromR (tproject1 (tproject2 m1)) * stranspose (sreplicate (sfromR v7))) ; m9 = sreplicate (scast v8) ; v10 = scast (ssum (sfromR (tproject1 (tproject2 (tproject1 m1))) * stranspose m9)) in tpair (tpair (tpair (sreplicate (sreplicate (sscalar 7.0)) * stranspose (sreplicate v10), v10), tpair (stranspose (m5 * m9), v8)), tpair (sreplicate (scast (ssum (m5 * stranspose (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1)))) * stranspose (sreplicate (sfromR v7)), v7)))"
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\x1 -> rfromS (ssum (stranspose (sreplicate (scast (ssum (stranspose (sreplicate (scast (ssum (stranspose (sreplicate (sreplicate (sscalar 7.0))) * stranspose (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) * stranspose (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))))) * sgather (sfromR (tproject1 (tproject2 m1))) (\\[i17, i18] -> [i18, i17])) + sfromR (tproject2 (tproject2 m1)))"
@?= "\\x1 -> rfromS (ssum (stranspose (sreplicate (scast (ssum (stranspose (sreplicate (scast (ssum (stranspose (sreplicate (sreplicate (sscalar 7.0))) * stranspose (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) * stranspose (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))))) * stranspose (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1)))"

testVT2OPPNonLin :: Assertion
testVT2OPPNonLin = do
Expand Down Expand Up @@ -816,6 +816,6 @@ testVT2OPPNonLin2 = do
printArtifactPrimalPretty renames artifactRevnonLin
@?= "\\x1 -> let v10 = ssum (stranspose (sreplicate (sreplicate (sscalar 7.0))) * stranspose (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))) ; v11 = exp (negate v10) ; v12 = sreplicate (sscalar 1.0) + v11 ; v13 = recip v12 ; m16 = stranspose (sreplicate (scast v13)) ; v17 = scast (ssum (m16 * stranspose (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))) ; v18 = exp (negate v17) ; v19 = sreplicate (sscalar 1.0) + v18 ; v20 = recip v19 ; m23 = stranspose (sreplicate v20) ; v24 = exp (ssum (m23 * stranspose (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x25 = ssum v24 ; v26 = sreplicate (recip x25) in rfromS (v26 * v24)"
printArtifactPretty renames (simplifyArtifact artifactRevnonLin)
@?= "\\v27 x1 -> tfromS (let v13 = recip (sreplicate (sscalar 1.0) + exp (negate (ssum (stranspose (sreplicate (sreplicate (sscalar 7.0))) * sgather (sfromR (tproject1 (tproject1 (tproject1 m1)))) (\\[i37, i38] -> [i38, i37])) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; m16 = stranspose (sreplicate (scast v13)) ; v20 = recip (sreplicate (sscalar 1.0) + exp (negate (scast (ssum (m16 * sgather (sfromR (tproject1 (tproject2 (tproject1 m1)))) (\\[i35, i36] -> [i36, i35]))) + sfromR (tproject2 (tproject2 (tproject1 m1)))))) ; v24 = exp (ssum (stranspose (sreplicate v20) * sgather (sfromR (tproject1 (tproject2 m1))) (\\[i33, i34] -> [i34, i33])) + sfromR (tproject2 (tproject2 m1))) ; x25 = ssum v24 ; v28 = v24 * (sreplicate (negate (recip (x25 * x25)) * sdot0 v24 (sfromR v27)) + sreplicate (recip x25) * sfromR v27) ; v30 = (v20 * (sreplicate (sscalar 1.0) - v20)) * ssum (sfromR (tproject1 (tproject2 m1)) * stranspose (sreplicate v28)) ; m31 = sreplicate (scast v30) ; v32 = (v13 * (sreplicate (sscalar 1.0) - v13)) * scast (ssum (sfromR (tproject1 (tproject2 (tproject1 m1))) * stranspose m31)) in tpair (tpair (tpair (sreplicate (sreplicate (sscalar 7.0)) * stranspose (sreplicate v32), v32), tpair (stranspose (m16 * m31), v30)), tpair (sreplicate v20 * stranspose (sreplicate v28), v28)))"
@?= "\\v27 x1 -> tfromS (let v13 = recip (sreplicate (sscalar 1.0) + exp (negate (ssum (stranspose (sreplicate (sreplicate (sscalar 7.0))) * stranspose (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; m16 = stranspose (sreplicate (scast v13)) ; v20 = recip (sreplicate (sscalar 1.0) + exp (negate (scast (ssum (m16 * stranspose (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1)))))) ; v24 = exp (ssum (stranspose (sreplicate v20) * stranspose (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x25 = ssum v24 ; v28 = v24 * (sreplicate (negate (recip (x25 * x25)) * sdot0 v24 (sfromR v27)) + sreplicate (recip x25) * sfromR v27) ; v30 = (v20 * (sreplicate (sscalar 1.0) - v20)) * ssum (sfromR (tproject1 (tproject2 m1)) * stranspose (sreplicate v28)) ; m31 = sreplicate (scast v30) ; v32 = (v13 * (sreplicate (sscalar 1.0) - v13)) * scast (ssum (sfromR (tproject1 (tproject2 (tproject1 m1))) * stranspose m31)) in tpair (tpair (tpair (sreplicate (sreplicate (sscalar 7.0)) * stranspose (sreplicate v32), v32), tpair (stranspose (m16 * m31), v30)), tpair (sreplicate v20 * stranspose (sreplicate v28), v28)))"
printArtifactPrimalPretty renames (simplifyArtifact artifactRevnonLin)
@?= "\\x1 -> rfromS (let v24 = exp (ssum (stranspose (sreplicate (recip (sreplicate (sscalar 1.0) + exp (negate (scast (ssum (stranspose (sreplicate (scast (recip (sreplicate (sscalar 1.0) + exp (negate (ssum (stranspose (sreplicate (sreplicate (sscalar 7.0))) * stranspose (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))))))))) * stranspose (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1)))))))) * sgather (sfromR (tproject1 (tproject2 m1))) (\\[i47, i48] -> [i48, i47])) + sfromR (tproject2 (tproject2 m1))) in sreplicate (recip (ssum v24)) * v24)"
@?= "\\x1 -> rfromS (let v24 = exp (ssum (stranspose (sreplicate (recip (sreplicate (sscalar 1.0) + exp (negate (scast (ssum (stranspose (sreplicate (scast (recip (sreplicate (sscalar 1.0) + exp (negate (ssum (stranspose (sreplicate (sreplicate (sscalar 7.0))) * stranspose (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))))))))) * stranspose (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1)))))))) * stranspose (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) in sreplicate (recip (ssum v24)) * v24)"
Loading

0 comments on commit 6e5852b

Please sign in to comment.