Skip to content

Commit

Permalink
Lift non-immediate expressions out of case values for the Nockma back…
Browse files Browse the repository at this point in the history
…end (#3010)

Implements a transformation `compute-case-anf` which lifts out
non-immediate values matched on in case expressions by introducing
let-bindings for them. In essence, this is a partial ANF transformation
for case expressions only.

For example, transforms
```
case f x of { c y := y + x; d y := y }
```
to
```
let z := f x in case z of { c y := y + x; d y := y }
```
This transformation is needed to avoid duplication of values matched on
in case-expressions in the Nockma backend.
  • Loading branch information
lukaszcz authored Sep 9, 2024
1 parent f47b9b0 commit 7167cb3
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 12 deletions.
2 changes: 1 addition & 1 deletion app/Commands/Dev/Core/Compile/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ runTreePipeline pa@PipelineArg {..} = do
r <-
runReader entryPoint
. runError @JuvixError
. coreToTree Core.IdentityTrans
. coreToTree Core.IdentityTrans []
$ _pipelineArgModule
tab' <- getRight r
let code = Tree.ppPrint tab' tab'
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ data TransformationId
| IdentityTrans
| UnrollRecursion
| ComputeTypeInfo
| ComputeCaseANF
| MatchToCase
| EtaExpandApps
| DisambiguateNames
Expand Down Expand Up @@ -91,6 +92,7 @@ instance TransformationId' TransformationId where
IntToPrimInt -> strIntToPrimInt
ConvertBuiltinTypes -> strConvertBuiltinTypes
ComputeTypeInfo -> strComputeTypeInfo
ComputeCaseANF -> strComputeCaseANF
UnrollRecursion -> strUnrollRecursion
DisambiguateNames -> strDisambiguateNames
CombineInfoTables -> strCombineInfoTables
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ strConvertBuiltinTypes = "convert-builtin-types"
strComputeTypeInfo :: Text
strComputeTypeInfo = "compute-type-info"

strComputeCaseANF :: Text
strComputeCaseANF = "compute-case-anf"

strUnrollRecursion :: Text
strUnrollRecursion = "unroll-recursion"

Expand Down
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Core/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,9 @@ toStripped checkId = mapReader fromEntryPoint . applyTransformations (toStripped
-- | Perform transformations on stored Core necessary before the translation to VampIR
toVampIR :: (Members '[Error JuvixError, Reader EntryPoint] r) => Module -> Sem r Module
toVampIR = mapReader fromEntryPoint . applyTransformations toVampIRTransformations

extraAnomaTransformations :: [TransformationId]
extraAnomaTransformations = [ComputeCaseANF]

applyExtraTransformations :: (Members '[Error JuvixError, Reader EntryPoint] r) => [TransformationId] -> Module -> Sem r Module
applyExtraTransformations transforms = mapReader fromEntryPoint . applyTransformations transforms
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import Juvix.Compiler.Core.Transformation.Check.Exec
import Juvix.Compiler.Core.Transformation.Check.Rust
import Juvix.Compiler.Core.Transformation.Check.VampIR
import Juvix.Compiler.Core.Transformation.CombineInfoTables (combineInfoTables)
import Juvix.Compiler.Core.Transformation.ComputeCaseANF
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo
import Juvix.Compiler.Core.Transformation.ConvertBuiltinTypes
import Juvix.Compiler.Core.Transformation.DisambiguateNames
Expand Down Expand Up @@ -72,6 +73,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
IntToPrimInt -> return . intToPrimInt
ConvertBuiltinTypes -> return . convertBuiltinTypes
ComputeTypeInfo -> return . computeTypeInfo
ComputeCaseANF -> return . computeCaseANF
UnrollRecursion -> unrollRecursion
MatchToCase -> mapError (JuvixError @CoreError) . matchToCase
EtaExpandApps -> return . etaExpansionApps
Expand Down
62 changes: 62 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/ComputeCaseANF.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
module Juvix.Compiler.Core.Transformation.ComputeCaseANF (computeCaseANF) where

-- A transformation which lifts out non-immediate values matched on in case
-- expressions by introducing let-bindings for them. In essence, this is a
-- partial ANF transformation for case expressions only.
--
-- For example, transforms
-- ```
-- case f x of { c y := y + x; d y := y }
-- ```
-- to
-- ```
-- let z := f x in case z of { c y := y + x; d y := y }
-- ```
-- This transformation is needed for the Nockma backend.

import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.TypeInfo qualified as Info
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo (computeNodeTypeInfo)

convertNode :: Module -> Node -> Node
convertNode md = Info.removeTypeInfo . rmapL go . computeNodeTypeInfo md
where
go :: ([BinderChange] -> Node -> Node) -> BinderList Binder -> Node -> Node
go recur bl node = case node of
NCase Case {..}
| not (isImmediate md _caseValue) ->
mkLet _caseInfo b val' $
NCase
Case
{ _caseValue = mkVar' 0,
_caseBranches = map goCaseBranch _caseBranches,
_caseDefault = fmap (go (recur . (BCAdd 1 :)) bl) _caseDefault,
_caseInfo,
_caseInductive
}
where
val' = go recur bl _caseValue
b = Binder "case_value" Nothing ty
ty = Info.getNodeType _caseValue

goCaseBranch :: CaseBranch -> CaseBranch
goCaseBranch CaseBranch {..} =
CaseBranch
{ _caseBranchBody =
go
(recur . ((BCAdd 1 : map BCKeep _caseBranchBinders) ++))
(BL.prependRev _caseBranchBinders bl)
_caseBranchBody,
_caseBranchTag,
_caseBranchInfo,
_caseBranchBindersNum,
_caseBranchBinders
}
_ ->
recur [] node

computeCaseANF :: Module -> Module
computeCaseANF md =
mapAllNodes (convertNode md) md
26 changes: 15 additions & 11 deletions src/Juvix/Compiler/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ upToTree ::
(Members '[HighlightBuilder, Reader Parser.ParserResult, Reader EntryPoint, Reader Store.ModuleTable, Files, NameIdGen, Error JuvixError] r) =>
Sem r Tree.InfoTable
upToTree =
upToStoredCore >>= \Core.CoreResult {..} -> storedCoreToTree Core.IdentityTrans _coreResultModule
upToStoredCore >>= \Core.CoreResult {..} -> storedCoreToTree Core.IdentityTrans [] _coreResultModule

upToAsm ::
(Members '[HighlightBuilder, Reader Parser.ParserResult, Reader EntryPoint, Reader Store.ModuleTable, Files, NameIdGen, Error JuvixError] r) =>
Expand Down Expand Up @@ -226,17 +226,21 @@ upToCoreTypecheck = do
storedCoreToTree ::
(Members '[Error JuvixError, Reader EntryPoint] r) =>
Core.TransformationId ->
[Core.TransformationId] ->
Core.Module ->
Sem r Tree.InfoTable
storedCoreToTree checkId md = do
storedCoreToTree checkId extraTransforms md = do
fsize <- asks (^. entryPointFieldSize)
Tree.fromCore . Stripped.fromCore fsize . Core.computeCombinedInfoTable <$> Core.toStripped checkId md
Tree.fromCore
. Stripped.fromCore fsize
. Core.computeCombinedInfoTable
<$> (Core.toStripped checkId md >>= Core.applyExtraTransformations extraTransforms)

storedCoreToAnoma :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r NockmaTree.AnomaResult
storedCoreToAnoma = storedCoreToTree Core.CheckAnoma >=> treeToAnoma
storedCoreToAnoma = storedCoreToTree Core.CheckAnoma Core.extraAnomaTransformations >=> treeToAnoma

storedCoreToAsm :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Asm.InfoTable
storedCoreToAsm = storedCoreToTree Core.CheckExec >=> treeToAsm
storedCoreToAsm = storedCoreToTree Core.CheckExec [] >=> treeToAsm

storedCoreToReg :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Reg.InfoTable
storedCoreToReg = storedCoreToAsm >=> asmToReg
Expand All @@ -245,13 +249,13 @@ storedCoreToMiniC :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.
storedCoreToMiniC = storedCoreToAsm >=> asmToMiniC

storedCoreToRust :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Rust.Result
storedCoreToRust = storedCoreToTree Core.CheckRust >=> treeToReg >=> regToRust
storedCoreToRust = storedCoreToTree Core.CheckRust [] >=> treeToReg >=> regToRust

storedCoreToRiscZeroRust :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Rust.Result
storedCoreToRiscZeroRust = storedCoreToTree Core.CheckRust >=> treeToReg >=> regToRiscZeroRust
storedCoreToRiscZeroRust = storedCoreToTree Core.CheckRust [] >=> treeToReg >=> regToRiscZeroRust

storedCoreToCasm :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Casm.Result
storedCoreToCasm = local (set entryPointFieldSize cairoFieldSize) . storedCoreToTree Core.CheckCairo >=> treeToCasm
storedCoreToCasm = local (set entryPointFieldSize cairoFieldSize) . storedCoreToTree Core.CheckCairo [] >=> treeToCasm

storedCoreToCairo :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Cairo.Result
storedCoreToCairo = storedCoreToCasm >=> casmToCairo
Expand All @@ -263,8 +267,8 @@ storedCoreToVampIR = Core.toVampIR >=> VampIR.fromCore . Core.computeCombinedInf
-- Workflows from Core
--------------------------------------------------------------------------------

coreToTree :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.TransformationId -> Core.Module -> Sem r Tree.InfoTable
coreToTree checkId = Core.toStored >=> storedCoreToTree checkId
coreToTree :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.TransformationId -> [Core.TransformationId] -> Core.Module -> Sem r Tree.InfoTable
coreToTree checkId extraTransforms = Core.toStored >=> storedCoreToTree checkId extraTransforms

coreToAsm :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Asm.InfoTable
coreToAsm = Core.toStored >=> storedCoreToAsm
Expand All @@ -279,7 +283,7 @@ coreToCairo :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module
coreToCairo = Core.toStored >=> storedCoreToCairo

coreToAnoma :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r NockmaTree.AnomaResult
coreToAnoma = coreToTree Core.CheckAnoma >=> treeToAnoma
coreToAnoma = coreToTree Core.CheckAnoma Core.extraAnomaTransformations >=> treeToAnoma

coreToRust :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Rust.Result
coreToRust = Core.toStored >=> storedCoreToRust
Expand Down

0 comments on commit 7167cb3

Please sign in to comment.