Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Structured temporary stack manipulation in JuvixAsm #2554

Merged
merged 3 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Asm/Data/InfoTable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ data FunctionInfo = FunctionInfo
-- (_functionType))` only if it is 0 (the "function" takes zero arguments)
-- and the result is a function.
_functionArgsNum :: Int,
-- | length _functionArgNames == _functionArgsNum
_functionArgNames :: [Maybe Text],
_functionType :: Type,
_functionMaxValueStackHeight :: Int,
_functionMaxTempStackHeight :: Int,
Expand All @@ -39,6 +41,8 @@ data ConstructorInfo = ConstructorInfo
-- (_constructorType))`. It is stored separately mainly for the benefit of
-- the interpreter (so it does not have to recompute it every time).
_constructorArgsNum :: Int,
-- | length _constructorArgNames == _constructorArgsNum
_constructorArgNames :: [Maybe Text],
-- | Constructor types are assumed to be fully uncurried, i.e., `uncurryType
-- _constructorType == _constructorType`
_constructorType :: Type,
Expand Down
8 changes: 4 additions & 4 deletions src/Juvix/Compiler/Asm/Extra/Memory.hs
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ getDirectRefType :: DirectRef -> Memory -> Maybe Type
getDirectRefType dr mem = case dr of
StackRef ->
topValueStack 0 mem
ArgRef off ->
getArgumentType off mem
TempRef off ->
bottomTempStack off mem
ArgRef OffsetRef {..} ->
getArgumentType _offsetRefOffset mem
TempRef OffsetRef {..} ->
bottomTempStack _offsetRefOffset mem

getValueType' :: (Member (Error AsmError) r) => Maybe Location -> InfoTable -> Memory -> Value -> Sem r Type
getValueType' loc tab mem = \case
Expand Down
60 changes: 44 additions & 16 deletions src/Juvix/Compiler/Asm/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ data RecursorSig m r a = RecursorSig
{ _recursorInfoTable :: InfoTable,
_recurseInstr :: m -> CmdInstr -> Sem r a,
_recurseBranch :: m -> CmdBranch -> [a] -> [a] -> Sem r a,
_recurseCase :: m -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a
_recurseCase :: m -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a,
_recurseSave :: m -> CmdSave -> [a] -> Sem r a
}

makeLenses ''RecursorSig
Expand All @@ -43,6 +44,8 @@ recurse' sig = go True
goNextCmd isTail (x ^. (cmdBranchInfo . commandInfoLocation)) (goBranch (isTail && null t) mem x) t
Case x ->
goNextCmd isTail (x ^. (cmdCaseInfo . commandInfoLocation)) (goCase (isTail && null t) mem x) t
Save x ->
goNextCmd isTail (x ^. (cmdSaveInfo . commandInfoLocation)) (goSave (isTail && null t) mem x) t

goNextCmd :: Bool -> Maybe Location -> Sem r (Memory, a) -> Code -> Sem r (Memory, [a])
goNextCmd isTail loc mp t = do
Expand Down Expand Up @@ -104,16 +107,6 @@ recurse' sig = go True
throw $
AsmError loc "popping empty value stack"
return (popValueStack 1 mem)
PushTemp -> do
when (null (mem ^. memoryValueStack)) $
throw $
AsmError loc "popping empty value stack"
return $ pushTempStack (topValueStack' 0 mem) (popValueStack 1 mem)
PopTemp -> do
when (null (mem ^. memoryTempStack)) $
throw $
AsmError loc "popping empty temporary stack"
return $ popTempStack 1 mem
Trace ->
return mem
Dump ->
Expand Down Expand Up @@ -275,6 +268,27 @@ recurse' sig = go True
where
loc = cmd ^. (cmdCaseInfo . commandInfoLocation)

goSave :: Bool -> Memory -> CmdSave -> Sem r (Memory, a)
goSave isTail mem cmd@CmdSave {..} = do
when (null (mem ^. memoryValueStack)) $
throw $
AsmError loc "popping empty value stack"
let mem1 = pushTempStack (topValueStack' 0 mem) (popValueStack 1 mem)
(mem2, a) <- go isTail mem1 _cmdSaveCode
a' <- (sig ^. recurseSave) mem cmd a
when (not isTail && _cmdSaveIsTail) $
throw $
AsmError loc "'tsave' not in tail position"
when (isTail && not _cmdSaveIsTail) $
throw $
AsmError loc "'save' in tail position"
when (not isTail && null (mem2 ^. memoryTempStack)) $
throw $
AsmError loc "popping empty temporary stack"
return (if isTail then mem2 else popTempStack 1 mem2, a')
where
loc = _cmdSaveInfo ^. commandInfoLocation

checkBranchInvariant :: Int -> Maybe Location -> Memory -> Memory -> Sem r ()
checkBranchInvariant k loc mem mem' = do
unless (length (mem' ^. memoryValueStack) == length (mem ^. memoryValueStack) + k) $
Expand Down Expand Up @@ -320,6 +334,8 @@ recurseS' sig = go
goNextCmd (goBranch si x) t
Case x ->
goNextCmd (goCase si x) t
Save x ->
goNextCmd (goSave si x) t

goNextCmd :: Sem r (StackInfo, a) -> Code -> Sem r (StackInfo, [a])
goNextCmd mp t = do
Expand Down Expand Up @@ -362,10 +378,6 @@ recurseS' sig = go
return (stackInfoPushValueStack 1 si)
Pop -> do
return (stackInfoPopValueStack 1 si)
PushTemp -> do
return $ stackInfoPushTempStack 1 (stackInfoPopValueStack 1 si)
PopTemp -> do
return $ stackInfoPopTempStack 1 si
Trace ->
return si
Dump ->
Expand Down Expand Up @@ -436,6 +448,14 @@ recurseS' sig = go
where
loc = cmd ^. (cmdCaseInfo . commandInfoLocation)

goSave :: StackInfo -> CmdSave -> Sem r (StackInfo, a)
goSave si cmd@CmdSave {..} = do
let si1 = stackInfoPushTempStack 1 (stackInfoPopValueStack 1 si)
(si2, c) <- go si1 _cmdSaveCode
c' <- (sig ^. recurseSave) si cmd c
let si' = if _cmdSaveIsTail then si2 else stackInfoPopTempStack 1 si2
return (si', c')

checkStackInfo :: Maybe Location -> StackInfo -> StackInfo -> Sem r ()
checkStackInfo loc si1 si2 =
when (si1 /= si2) $
Expand Down Expand Up @@ -463,7 +483,8 @@ data FoldSig m r a = FoldSig
_foldAdjust :: a -> a,
_foldInstr :: m -> CmdInstr -> a -> Sem r a,
_foldBranch :: m -> CmdBranch -> a -> a -> a -> Sem r a,
_foldCase :: m -> CmdCase -> [a] -> Maybe a -> a -> Sem r a
_foldCase :: m -> CmdCase -> [a] -> Maybe a -> a -> Sem r a,
_foldSave :: m -> CmdSave -> a -> a -> Sem r a
}

makeLenses ''FoldSig
Expand Down Expand Up @@ -499,6 +520,13 @@ foldS' sig si code acc = do
Just d -> Just <$> compose' d a'
Nothing -> return Nothing
(sig ^. foldCase) s cmd as ad a
),
_recurseSave = \s cmd br ->
return
( \a -> do
let a' = (sig ^. foldAdjust) a
a'' <- compose' br a'
(sig ^. foldSave) s cmd a'' a
)
}

Expand Down
19 changes: 11 additions & 8 deletions src/Juvix/Compiler/Asm/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ runCodeR infoTable funInfo = goCode (funInfo ^. functionCode) >> popLastValueSta
_ -> case def of
Just x -> goCode x
Nothing -> runtimeError "no matching branch"
Save CmdSave {..} -> do
registerLocation (_cmdSaveInfo ^. commandInfoLocation)
v <- popValueStack
pushTempStack v
if
| _cmdSaveIsTail ->
goCode _cmdSaveCode
| otherwise ->
goCode _cmdSaveCode >> popTempStack >> goCode cont

goInstr :: (Member Runtime r) => Maybe Location -> Instruction -> Code -> Sem r ()
goInstr loc instr cont = case instr of
Expand Down Expand Up @@ -109,12 +118,6 @@ runCodeR infoTable funInfo = goCode (funInfo ^. functionCode) >> popLastValueSta
goCode cont
Pop ->
popValueStack >> goCode cont
PushTemp -> do
v <- popValueStack
pushTempStack v
goCode cont
PopTemp ->
popTempStack >> goCode cont
Trace -> do
v <- topValueStack
logMessage (printVal v)
Expand Down Expand Up @@ -225,8 +228,8 @@ runCodeR infoTable funInfo = goCode (funInfo ^. functionCode) >> popLastValueSta
getDirectRef :: (Member Runtime r) => DirectRef -> Sem r Val
getDirectRef = \case
StackRef -> topValueStack
ArgRef off -> readArg off
TempRef off -> readTemp off
ArgRef OffsetRef {..} -> readArg _offsetRefOffset
TempRef OffsetRef {..} -> readTemp _offsetRefOffset

popLastValueStack :: (Member Runtime r) => Sem r Val
popLastValueStack = do
Expand Down
32 changes: 25 additions & 7 deletions src/Juvix/Compiler/Asm/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,28 @@ data DirectRef
StackRef
| -- | ArgRef references an argument in the argument area (0-based offsets).
-- JVA code: 'arg[<offset>]'.
ArgRef Offset
ArgRef OffsetRef
| -- | TempRef references a value in the temporary area (0-based offsets). JVA
-- code: 'tmp[<offset>]'.
TempRef Offset
TempRef OffsetRef

data OffsetRef = OffsetRef
{ _offsetRefOffset :: Offset,
_offsetRefName :: Maybe Text
}

-- | Constructor field reference. JVA code: '<dref>.<tag>[<offset>]'
data Field = Field
{ -- | tag of the constructor being referenced
{ _fieldName :: Maybe Text,
-- | tag of the constructor being referenced
_fieldTag :: Tag,
-- | location where the data is stored
_fieldRef :: DirectRef,
_fieldOffset :: Offset
}

makeLenses ''Field
makeLenses ''OffsetRef

-- | Function call type
data CallType = CallFun Symbol | CallClosure
Expand All @@ -80,10 +87,6 @@ data Instruction
Push Value
| -- | Pop the stack. JVA opcode: 'pop'.
Pop
| -- | Push the top of the value stack onto the temporary stack, pop the value
-- stack. Used to implement Core.Let and Core.Case. JVA opcodes: 'pusht', 'popt'.
PushTemp
| PopTemp
| -- | Print a debug log of the object on top of the stack. Does not pop the
-- stack. JVA opcode: 'trace'.
Trace
Expand Down Expand Up @@ -220,6 +223,13 @@ data Command
-- JVA code: 'case <ind> { <tag>: {<code>} ... <tag>: {<code>} default: {<code>} }'
-- (any branch may be omitted).
Case CmdCase
| -- | Push the top of the value stack onto the temporary stack, pop the value
-- stack, execute the nested code, and if not tail recursive then pop the
-- temporary stack afterwards. Used to implement Core.Let and Core.Case. JVA
-- codes: 'save {<code>}', 'save <name> {<code>}', 'tsave {<code>}', 'tsave
-- <name> {<code>}'. The 'tsave' version does not pop the temporary stack
-- after executing '<code>' (which is supposed to return).
Save CmdSave

newtype CommandInfo = CommandInfo
{ _commandInfoLocation :: Maybe Location
Expand Down Expand Up @@ -251,6 +261,13 @@ data CaseBranch = CaseBranch
_caseBranchCode :: Code
}

data CmdSave = CmdSave
{ _cmdSaveInfo :: CommandInfo,
_cmdSaveIsTail :: Bool,
_cmdSaveName :: Maybe Text,
_cmdSaveCode :: Code
}

-- | `Code` corresponds to JuvixAsm code for a single function.
type Code = [Command]

Expand All @@ -263,3 +280,4 @@ makeLenses ''CmdInstr
makeLenses ''CmdBranch
makeLenses ''CmdCase
makeLenses ''CaseBranch
makeLenses ''CmdSave
22 changes: 15 additions & 7 deletions src/Juvix/Compiler/Asm/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,16 @@ instance PrettyCode Type where
TyFun x ->
ppCode x

ppOffsetRef :: Text -> OffsetRef -> Sem r (Doc Ann)
ppOffsetRef str OffsetRef {..} =
return $ maybe (variable str <> lbracket <> integer _offsetRefOffset <> rbracket) variable _offsetRefName

instance PrettyCode DirectRef where
ppCode :: DirectRef -> Sem r (Doc Ann)
ppCode = \case
StackRef -> return $ variable Str.dollar
ArgRef off -> return $ variable Str.arg <> lbracket <> integer off <> rbracket
TempRef off -> return $ variable Str.tmp <> lbracket <> integer off <> rbracket
ArgRef roff -> ppOffsetRef Str.arg roff
TempRef roff -> ppOffsetRef Str.tmp roff

instance PrettyCode Field where
ppCode :: (Member (Reader Options) r) => Field -> Sem r (Doc Ann)
Expand Down Expand Up @@ -273,8 +277,6 @@ instance PrettyCode Instruction where
StrToInt -> return $ primitive Str.instrStrToInt
Push val -> (primitive Str.instrPush <+>) <$> ppCode val
Pop -> return $ primitive Str.instrPop
PushTemp -> return $ primitive Str.instrPusht
PopTemp -> return $ primitive Str.instrPopt
Trace -> return $ primitive Str.instrTrace
Dump -> return $ primitive Str.instrDump
Failure -> return $ primitive Str.instrFailure
Expand Down Expand Up @@ -336,6 +338,10 @@ instance PrettyCode Command where
return $ brs ++ [d]
Nothing -> return brs
return $ primitive Str.case_ <+> name <+> braces' (vsep brs')
Save CmdSave {..} -> do
c <- ppCodeCode _cmdSaveCode
let s = if _cmdSaveIsTail then Str.tsave else Str.save
return $ primitive s <+> (maybe mempty ((<> space) . variable) _cmdSaveName) <> braces' c

instance (PrettyCode a) => PrettyCode [a] where
ppCode x = do
Expand All @@ -344,13 +350,15 @@ instance (PrettyCode a) => PrettyCode [a] where

instance PrettyCode FunctionInfo where
ppCode FunctionInfo {..} = do
argtys <- mapM ppCode (typeArgs _functionType)
targetty <- ppCode (typeTarget _functionType)
argtys <- mapM ppCode (take _functionArgsNum (typeArgs _functionType))
let argnames = map (fmap variable) _functionArgNames
args = zipWithExact (\mn ty -> maybe mempty (\n -> n <+> colon <> space) mn <> ty) argnames argtys
targetty <- ppCode (if _functionArgsNum == 0 then _functionType else typeTarget _functionType)
c <- ppCodeCode _functionCode
return $
keyword Str.function
<+> annotate (AnnKind KNameFunction) (pretty (quoteAsmFunName $ quoteAsmName _functionName))
<> encloseSep lparen rparen ", " argtys
<> encloseSep lparen rparen ", " args
<+> colon
<+> targetty
<+> braces' c
Expand Down
19 changes: 17 additions & 2 deletions src/Juvix/Compiler/Asm/Transformation/Prealloc.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ computeCodePrealloc maxArgsNum tab code = prealloc <$> foldS sig code (0, [])
_foldAdjust = second (const []),
_foldInstr = const goInstr,
_foldBranch = const goBranch,
_foldCase = const goCase
_foldCase = const goCase,
_foldSave = const goSave
}

goInstr :: CmdInstr -> (Int, Code) -> Sem r (Int, Code)
Expand Down Expand Up @@ -77,6 +78,15 @@ computeCodePrealloc maxArgsNum tab code = prealloc <$> foldS sig code (0, [])
_cmdCaseDefault = fmap prealloc md
}

goSave :: CmdSave -> (Int, Code) -> (Int, Code) -> Sem r (Int, Code)
goSave cmd (k, br) (_, c) = return (k, cmd' : c)
where
cmd' =
Save
cmd
{ _cmdSaveCode = br
}

prealloc :: (Int, Code) -> Code
prealloc (0, c) = c
prealloc (n, c) = mkInstr (Prealloc (InstrPrealloc n)) : c
Expand All @@ -100,7 +110,8 @@ checkCodePrealloc maxArgsNum tab code = do
_foldAdjust = id,
_foldInstr = const goInstr,
_foldBranch = const goBranch,
_foldCase = const goCase
_foldCase = const goCase,
_foldSave = const goSave
}

goInstr :: CmdInstr -> (Int -> Int) -> Sem r (Int -> Int)
Expand Down Expand Up @@ -145,6 +156,10 @@ checkCodePrealloc maxArgsNum tab code = do
k' = min (minimum ks) (fromMaybe k kd)
in cont k'

goSave :: CmdSave -> (Int -> Int) -> (Int -> Int) -> Sem r (Int -> Int)
goSave _ br cont =
return $ cont . br

checkPrealloc :: Options -> InfoTable -> Bool
checkPrealloc opts tab =
case run $ runError $ runReader opts sb of
Expand Down
5 changes: 5 additions & 0 deletions src/Juvix/Compiler/Asm/Transformation/StackUsage.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ computeFunctionStackUsage tab fi = do
return
( max (si ^. stackInfoValueStackHeight) (max (maximum (map (maximum . map fst) cs)) (maybe 0 (maximum . map fst) md)),
max (si ^. stackInfoTempStackHeight) (max (maximum (map (maximum . map snd) cs)) (maybe 0 (maximum . map snd) md))
),
_recurseSave = \si _ b ->
return
( max (si ^. stackInfoValueStackHeight) (maximum (map fst b)),
max (si ^. stackInfoTempStackHeight) (maximum (map snd b))
)
}

Expand Down
3 changes: 2 additions & 1 deletion src/Juvix/Compiler/Asm/Transformation/Validate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ validateCode tab fi code = do
{ _recursorInfoTable = tab,
_recurseInstr = \_ _ -> return (),
_recurseBranch = \_ _ _ _ -> return (),
_recurseCase = \_ _ _ _ -> return ()
_recurseCase = \_ _ _ _ -> return (),
_recurseSave = \_ _ _ -> return ()
}

validateFunction :: (Member (Error AsmError) r) => InfoTable -> FunctionInfo -> Sem r FunctionInfo
Expand Down
Loading