Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lift non-immediate expressions out of case values for the Nockma backend #3010

Merged
merged 4 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/Commands/Dev/Core/Compile/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ runTreePipeline pa@PipelineArg {..} = do
r <-
runReader entryPoint
. runError @JuvixError
. coreToTree Core.IdentityTrans
. coreToTree Core.IdentityTrans []
$ _pipelineArgModule
tab' <- getRight r
let code = Tree.ppPrint tab' tab'
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ data TransformationId
| IdentityTrans
| UnrollRecursion
| ComputeTypeInfo
| ComputeCaseANF
| MatchToCase
| EtaExpandApps
| DisambiguateNames
Expand Down Expand Up @@ -91,6 +92,7 @@ instance TransformationId' TransformationId where
IntToPrimInt -> strIntToPrimInt
ConvertBuiltinTypes -> strConvertBuiltinTypes
ComputeTypeInfo -> strComputeTypeInfo
ComputeCaseANF -> strComputeCaseANF
UnrollRecursion -> strUnrollRecursion
DisambiguateNames -> strDisambiguateNames
CombineInfoTables -> strCombineInfoTables
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ strConvertBuiltinTypes = "convert-builtin-types"
strComputeTypeInfo :: Text
strComputeTypeInfo = "compute-type-info"

strComputeCaseANF :: Text
strComputeCaseANF = "compute-case-anf"

strUnrollRecursion :: Text
strUnrollRecursion = "unroll-recursion"

Expand Down
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Core/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,9 @@ toStripped checkId = mapReader fromEntryPoint . applyTransformations (toStripped
-- | Perform transformations on stored Core necessary before the translation to VampIR
toVampIR :: (Members '[Error JuvixError, Reader EntryPoint] r) => Module -> Sem r Module
toVampIR = mapReader fromEntryPoint . applyTransformations toVampIRTransformations

extraAnomaTransformations :: [TransformationId]
extraAnomaTransformations = [ComputeCaseANF]

applyExtraTransformations :: (Members '[Error JuvixError, Reader EntryPoint] r) => [TransformationId] -> Module -> Sem r Module
applyExtraTransformations transforms = mapReader fromEntryPoint . applyTransformations transforms
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import Juvix.Compiler.Core.Transformation.Check.Exec
import Juvix.Compiler.Core.Transformation.Check.Rust
import Juvix.Compiler.Core.Transformation.Check.VampIR
import Juvix.Compiler.Core.Transformation.CombineInfoTables (combineInfoTables)
import Juvix.Compiler.Core.Transformation.ComputeCaseANF
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo
import Juvix.Compiler.Core.Transformation.ConvertBuiltinTypes
import Juvix.Compiler.Core.Transformation.DisambiguateNames
Expand Down Expand Up @@ -72,6 +73,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
IntToPrimInt -> return . intToPrimInt
ConvertBuiltinTypes -> return . convertBuiltinTypes
ComputeTypeInfo -> return . computeTypeInfo
ComputeCaseANF -> return . computeCaseANF
UnrollRecursion -> unrollRecursion
MatchToCase -> mapError (JuvixError @CoreError) . matchToCase
EtaExpandApps -> return . etaExpansionApps
Expand Down
62 changes: 62 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/ComputeCaseANF.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
module Juvix.Compiler.Core.Transformation.ComputeCaseANF (computeCaseANF) where

-- A transformation which lifts out non-immediate values matched on in case
-- expressions by introducing let-bindings for them. In essence, this is a
-- partial ANF transformation for case expressions only.
--
-- For example, transforms
-- ```
-- case f x of { c y := y + x; d y := y }
-- ```
-- to
-- ```
-- let z := f x in case z of { c y := y + x; d y := y }
-- ```
-- This transformation is needed for the Nockma backend.

import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.TypeInfo qualified as Info
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo (computeNodeTypeInfo)

convertNode :: Module -> Node -> Node
convertNode md = Info.removeTypeInfo . rmapL go . computeNodeTypeInfo md
where
go :: ([BinderChange] -> Node -> Node) -> BinderList Binder -> Node -> Node
go recur bl node = case node of
NCase Case {..}
| not (isImmediate md _caseValue) ->
mkLet _caseInfo b val' $
NCase
Case
{ _caseValue = mkVar' 0,
_caseBranches = map goCaseBranch _caseBranches,
_caseDefault = fmap (go (recur . (BCAdd 1 :)) bl) _caseDefault,
_caseInfo,
_caseInductive
}
where
val' = go recur bl _caseValue
b = Binder "case_value" Nothing ty
ty = Info.getNodeType _caseValue

goCaseBranch :: CaseBranch -> CaseBranch
goCaseBranch CaseBranch {..} =
CaseBranch
{ _caseBranchBody =
go
(recur . ((BCAdd 1 : map BCKeep _caseBranchBinders) ++))
(BL.prependRev _caseBranchBinders bl)
_caseBranchBody,
_caseBranchTag,
_caseBranchInfo,
_caseBranchBindersNum,
_caseBranchBinders
}
_ ->
recur [] node

computeCaseANF :: Module -> Module
computeCaseANF md =
mapAllNodes (convertNode md) md
26 changes: 15 additions & 11 deletions src/Juvix/Compiler/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ upToTree ::
(Members '[HighlightBuilder, Reader Parser.ParserResult, Reader EntryPoint, Reader Store.ModuleTable, Files, NameIdGen, Error JuvixError] r) =>
Sem r Tree.InfoTable
upToTree =
upToStoredCore >>= \Core.CoreResult {..} -> storedCoreToTree Core.IdentityTrans _coreResultModule
upToStoredCore >>= \Core.CoreResult {..} -> storedCoreToTree Core.IdentityTrans [] _coreResultModule

upToAsm ::
(Members '[HighlightBuilder, Reader Parser.ParserResult, Reader EntryPoint, Reader Store.ModuleTable, Files, NameIdGen, Error JuvixError] r) =>
Expand Down Expand Up @@ -226,17 +226,21 @@ upToCoreTypecheck = do
storedCoreToTree ::
(Members '[Error JuvixError, Reader EntryPoint] r) =>
Core.TransformationId ->
[Core.TransformationId] ->
Core.Module ->
Sem r Tree.InfoTable
storedCoreToTree checkId md = do
storedCoreToTree checkId extraTransforms md = do
fsize <- asks (^. entryPointFieldSize)
Tree.fromCore . Stripped.fromCore fsize . Core.computeCombinedInfoTable <$> Core.toStripped checkId md
Tree.fromCore
. Stripped.fromCore fsize
. Core.computeCombinedInfoTable
<$> (Core.toStripped checkId md >>= Core.applyExtraTransformations extraTransforms)

storedCoreToAnoma :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r NockmaTree.AnomaResult
storedCoreToAnoma = storedCoreToTree Core.CheckAnoma >=> treeToAnoma
storedCoreToAnoma = storedCoreToTree Core.CheckAnoma Core.extraAnomaTransformations >=> treeToAnoma

storedCoreToAsm :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Asm.InfoTable
storedCoreToAsm = storedCoreToTree Core.CheckExec >=> treeToAsm
storedCoreToAsm = storedCoreToTree Core.CheckExec [] >=> treeToAsm

storedCoreToReg :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Reg.InfoTable
storedCoreToReg = storedCoreToAsm >=> asmToReg
Expand All @@ -245,13 +249,13 @@ storedCoreToMiniC :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.
storedCoreToMiniC = storedCoreToAsm >=> asmToMiniC

storedCoreToRust :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Rust.Result
storedCoreToRust = storedCoreToTree Core.CheckRust >=> treeToReg >=> regToRust
storedCoreToRust = storedCoreToTree Core.CheckRust [] >=> treeToReg >=> regToRust

storedCoreToRiscZeroRust :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Rust.Result
storedCoreToRiscZeroRust = storedCoreToTree Core.CheckRust >=> treeToReg >=> regToRiscZeroRust
storedCoreToRiscZeroRust = storedCoreToTree Core.CheckRust [] >=> treeToReg >=> regToRiscZeroRust

storedCoreToCasm :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Casm.Result
storedCoreToCasm = local (set entryPointFieldSize cairoFieldSize) . storedCoreToTree Core.CheckCairo >=> treeToCasm
storedCoreToCasm = local (set entryPointFieldSize cairoFieldSize) . storedCoreToTree Core.CheckCairo [] >=> treeToCasm

storedCoreToCairo :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Cairo.Result
storedCoreToCairo = storedCoreToCasm >=> casmToCairo
Expand All @@ -263,8 +267,8 @@ storedCoreToVampIR = Core.toVampIR >=> VampIR.fromCore . Core.computeCombinedInf
-- Workflows from Core
--------------------------------------------------------------------------------

coreToTree :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.TransformationId -> Core.Module -> Sem r Tree.InfoTable
coreToTree checkId = Core.toStored >=> storedCoreToTree checkId
coreToTree :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.TransformationId -> [Core.TransformationId] -> Core.Module -> Sem r Tree.InfoTable
coreToTree checkId extraTransforms = Core.toStored >=> storedCoreToTree checkId extraTransforms

coreToAsm :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Asm.InfoTable
coreToAsm = Core.toStored >=> storedCoreToAsm
Expand All @@ -279,7 +283,7 @@ coreToCairo :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module
coreToCairo = Core.toStored >=> storedCoreToCairo

coreToAnoma :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r NockmaTree.AnomaResult
coreToAnoma = coreToTree Core.CheckAnoma >=> treeToAnoma
coreToAnoma = coreToTree Core.CheckAnoma Core.extraAnomaTransformations >=> treeToAnoma

coreToRust :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Rust.Result
coreToRust = Core.toStored >=> storedCoreToRust
Expand Down
Loading