Skip to content

Commit

Permalink
The assert builtin (#3014)
Browse files Browse the repository at this point in the history
* Requires #3015
  • Loading branch information
lukaszcz authored Sep 12, 2024
1 parent 8e20463 commit 26ea94b
Show file tree
Hide file tree
Showing 62 changed files with 309 additions and 28 deletions.
21 changes: 8 additions & 13 deletions examples/milestone/Bank/Bank.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ calculateInterest : Nat -> Nat -> Nat -> Nat
incrAmount (a : Nat) : Nat := div (a * rate) 10000;
in iterate (min 100 periods) incrAmount amount;

--- Asserts some ;Bool; condition.
assert : {A : Type} -> Bool -> A -> A
| c a := ite c a (failwith "assertion failed");

--- Returns a new ;Token;. Arguments are:
---
--- `owner`: The address of the account to issue the token to
Expand All @@ -82,7 +78,7 @@ assert : {A : Type} -> Bool -> A -> A
---
--- `caller`: Who is creating the transaction. It must be the bank.
issue : Address -> Address -> Nat -> Token
| caller owner amount := assert (caller == bankAddress) (mkToken owner 0 amount);
| caller owner amount := assert (caller == bankAddress) >-> mkToken owner 0 amount;

--- Deposits some amount of money into the bank.
deposit (bal : Balances) (token : Token) (amount : Nat) : Token :=
Expand All @@ -102,11 +98,10 @@ withdraw
(rate : Nat)
(periods : Nat)
: Token :=
assert
(caller == bankAddress)
(let
hash : Field := hashAddress recipient;
total : Nat := calculateInterest amount rate periods;
token : Token := mkToken recipient 0 total;
bal' : Balances := decrement hash amount bal;
in runOnChain (commitBalances bal') token);
assert (caller == bankAddress)
>-> let
hash : Field := hashAddress recipient;
total : Nat := calculateInterest amount rate periods;
token : Token := mkToken recipient 0 total;
bal' : Balances := decrement hash amount bal;
in runOnChain (commitBalances bal') token;
2 changes: 1 addition & 1 deletion juvix-stdlib
1 change: 1 addition & 0 deletions runtime/c/src/juvix/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
#define JUVIX_INT_TO_UINT8(var0, var1) \
(var0 = make_smallint((word_t)((uint8_t)(get_unboxed_int(var1) & 0xFF))))

#define JUVIX_ASSERT(val) (assert(is_true(val)))
#define JUVIX_TRACE(val) (io_trace(val))
#define JUVIX_DUMP (stacktrace_dump())
#define JUVIX_FAILURE(val) \
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Asm/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ recurse' sig = go True
throw $
AsmError loc "popping empty value stack"
return (popValueStack 1 mem)
Assert ->
return mem
Trace ->
return mem
Dump ->
Expand Down Expand Up @@ -412,6 +414,8 @@ recurseS' sig = go True
return (stackInfoPushValueStack 1 si)
Pop -> do
return (stackInfoPopValueStack 1 si)
Assert ->
return si
Trace ->
return si
Dump ->
Expand Down
5 changes: 5 additions & 0 deletions src/Juvix/Compiler/Asm/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ runCodeR infoTable funInfo = goCode (funInfo ^. functionCode) >> popLastValueSta
goCode cont
Pop ->
popValueStack >> goCode cont
Assert -> do
v <- topValueStack
unless (v == ValBool True) $
runtimeError "assertion failed"
goCode cont
Trace -> do
v <- topValueStack
logMessage (printValue infoTable v)
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Asm/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ data Instruction
Push Value
| -- | Pop the stack. JVA opcode: 'pop'.
Pop
| -- | Assert a boolean on top of the stack. Does not pop the stack. JVA
-- opcode: 'assert'.
Assert
| -- | Print a debug log of the object on top of the stack. Does not pop the
-- stack. JVA opcode: 'trace'.
Trace
Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Asm/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ instance PrettyCode Instruction where
Cairo op -> Tree.ppCode op
Push val -> (primitive Str.instrPush <+>) <$> ppCode val
Pop -> return $ primitive Str.instrPop
Assert -> return $ primitive Str.instrAssert
Trace -> return $ primitive Str.instrTrace
Dump -> return $ primitive Str.instrDump
Failure -> return $ primitive Str.instrFailure
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Asm/Translation/FromSource.hs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ command = do
mkInstr' loc . Push <$> value
"pop" ->
return $ mkInstr' loc Pop
"assert" ->
return $ mkInstr' loc Assert
"trace" ->
return $ mkInstr' loc Trace
"dump" ->
Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Asm/Translation/FromTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ genCode fi =
genUnOp :: Tree.UnaryOpcode -> Command
genUnOp op = case op of
Tree.PrimUnop op' -> mkUnop op'
Tree.OpAssert -> mkInstr Assert
Tree.OpTrace -> mkInstr Trace
Tree.OpFail -> mkInstr Failure

Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Backend/C/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ fromRegInstr bNoStack info = \case
unsupported "Cairo builtin"
Reg.Assign Reg.InstrAssign {..} ->
return $ stmtsAssign (fromVarRef _instrAssignResult) (fromValue _instrAssignValue)
Reg.Assert Reg.InstrAssert {..} ->
return [StatementExpr $ macroCall "JUVIX_ASSERT" [fromValue _instrAssertValue]]
Reg.Trace Reg.InstrTrace {..} ->
return [StatementExpr $ macroCall "JUVIX_TRACE" [fromValue _instrTraceValue]]
Reg.Dump ->
Expand Down
9 changes: 9 additions & 0 deletions src/Juvix/Compiler/Backend/Cairo/Translation/FromCasm.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ fromCasm instrs0 =
Casm.Return -> goReturn
Casm.Alloc x -> goAlloc x
Casm.Hint x -> goHint x
Casm.Assert x -> goAssert x
Casm.Trace {} -> []
Casm.Label {} -> []
Casm.Nop -> []
Expand Down Expand Up @@ -230,6 +231,14 @@ fromCasm instrs0 =
. set instrApUpdate ApUpdateAdd
$ defaultInstruction

goAssert :: Casm.InstrAssert -> [Element]
goAssert Casm.InstrAssert {..} =
toElems
. updateOps False (Casm.Val (Casm.Imm 0))
. updateDst _instrAssertValue
. set instrOpcode AssertEq
$ defaultInstruction

goHint :: Casm.Hint -> [Element]
goHint = \case
Casm.HintInput var -> [ElementHint (Hint ("Input(" <> var <> ")"))]
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Backend/Rust/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ fromRegInstr info = \case
unsupported "Cairo builtin"
Reg.Assign Reg.InstrAssign {..} ->
stmtsAssign (mkVarRef _instrAssignResult) (fromValue _instrAssignValue)
Reg.Assert {} ->
unsupported "assert"
Reg.Trace {} ->
unsupported "trace"
Reg.Dump ->
Expand Down
24 changes: 24 additions & 0 deletions src/Juvix/Compiler/Builtins/Assert.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module Juvix.Compiler.Builtins.Assert where

import Juvix.Compiler.Internal.Builtins
import Juvix.Compiler.Internal.Extra
import Juvix.Prelude

checkAssert :: (Members '[Reader BuiltinsTable, Error ScoperError, NameIdGen] r) => FunctionDef -> Sem r ()
checkAssert f = do
bool_ <- getBuiltinNameScoper (getLoc f) BuiltinBool
let assert_ = f ^. funDefName
l = getLoc f
varx <- freshVar l "x"
let x = toExpression varx
assertClauses :: [(Expression, Expression)]
assertClauses = [(assert_ @@ x, x)]
checkBuiltinFunctionInfo
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinAssert,
_funInfoSignature = bool_ --> bool_,
_funInfoClauses = assertClauses,
_funInfoFreeVars = [varx],
_funInfoFreeTypeVars = []
}
8 changes: 8 additions & 0 deletions src/Juvix/Compiler/Casm/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ hRunCode hout inputInfo (LabelInfo labelInfo) instrs0 = runST goCode
Call x -> goCall x pc ap fp mem
Return -> goReturn pc ap fp mem
Alloc x -> goAlloc x pc ap fp mem
Assert x -> goAssert x pc ap fp mem
Trace x -> goTrace x pc ap fp mem
Hint x -> goHint x pc ap fp mem
Label {} -> go (pc + 1) ap fp mem
Expand Down Expand Up @@ -244,6 +245,13 @@ hRunCode hout inputInfo (LabelInfo labelInfo) instrs0 = runST goCode
v <- readRValue ap fp mem _instrAllocSize
go (pc + 1) (ap + fromInteger (fieldToInteger v)) fp mem

goAssert :: InstrAssert -> Address -> Address -> Address -> Memory s -> ST s FField
goAssert InstrAssert {..} pc ap fp mem = do
v <- readMemRef ap fp mem _instrAssertValue
when (fieldToInteger v /= 0) $
throwRunError "assertion failed"
go (pc + 1) ap fp mem

goTrace :: InstrTrace -> Address -> Address -> Address -> Memory s -> ST s FField
goTrace InstrTrace {..} pc ap fp mem = do
v <- readRValue ap fp mem _instrTraceValue
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Casm/Keywords.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Juvix.Data.Keyword.All
kwAbs,
kwAp,
kwApPlusPlus,
kwAssert,
kwCall,
kwColon,
kwDiv,
Expand Down Expand Up @@ -45,6 +46,7 @@ allKeywords =
kwAbs,
kwAp,
kwApPlusPlus,
kwAssert,
kwCall,
kwColon,
kwDiv,
Expand Down
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Casm/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ data Instruction
| Call InstrCall
| Return
| Alloc InstrAlloc
| Assert InstrAssert
| Trace InstrTrace
| Hint Hint
| Label LabelRef
Expand Down Expand Up @@ -132,6 +133,10 @@ newtype InstrAlloc = InstrAlloc
{ _instrAllocSize :: RValue
}

newtype InstrAssert = InstrAssert
{ _instrAssertValue :: MemRef
}

newtype InstrTrace = InstrTrace
{ _instrTraceValue :: RValue
}
Expand All @@ -148,4 +153,5 @@ makeLenses ''InstrJump
makeLenses ''InstrJumpIf
makeLenses ''InstrCall
makeLenses ''InstrAlloc
makeLenses ''InstrAssert
makeLenses ''InstrTrace
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Casm/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ instance PrettyCode InstrAlloc where
s <- ppCode _instrAllocSize
return $ Str.ap <+> Str.plusequal <+> s

instance PrettyCode InstrAssert where
ppCode InstrAssert {..} = do
v <- ppCode _instrAssertValue
return $ Str.assert_ <+> v

instance PrettyCode InstrTrace where
ppCode InstrTrace {..} = do
v <- ppCode _instrTraceValue
Expand All @@ -185,6 +190,7 @@ instance PrettyCode Instruction where
Call x -> ppCode x
Return -> return Str.ret
Alloc x -> ppCode x
Assert x -> ppCode x
Trace x -> ppCode x
Hint x -> ppCode x
Label x -> (<> colon) <$> ppCode x
Expand Down
13 changes: 13 additions & 0 deletions src/Juvix/Compiler/Casm/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
Reg.Assign x -> goAssign x
Reg.Alloc x -> goAlloc x
Reg.AllocClosure x -> goAllocClosure x
Reg.Assert x -> goAssert x
Reg.Trace x -> goTrace x
Reg.Dump -> unsupported "dump"
Reg.Failure x -> goFail x
Expand Down Expand Up @@ -512,6 +513,18 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
storedArgsNum = length _instrAllocClosureArgs
leftArgsNum = _instrAllocClosureExpectedArgsNum - storedArgsNum

goAssert :: Reg.InstrAssert -> Sem r ()
goAssert Reg.InstrAssert {..} = do
v <- goValue _instrAssertValue
case v of
Imm c
| c == 0 -> return ()
| otherwise ->
output' 0 $ mkAssign (MemRef Ap 0) (Binop $ BinopValue FieldAdd (MemRef Ap 0) (Imm 1))
Ref r ->
output' 0 $ Assert (InstrAssert r)
Lab {} -> unsupported "assert label"

goTrace :: Reg.InstrTrace -> Sem r ()
goTrace Reg.InstrTrace {..} = do
v <- mkRValue _instrTraceValue
Expand Down
7 changes: 7 additions & 0 deletions src/Juvix/Compiler/Casm/Translation/FromSource.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ instruction =
<|> parseJump
<|> parseCall
<|> parseReturn
<|> parseAssert
<|> parseTrace
<|> parseAssign

Expand Down Expand Up @@ -249,6 +250,12 @@ parseReturn = do
kw kwRet
return Return

parseAssert :: ParsecS r Instruction
parseAssert = do
kw kwAssert
r <- parseMemRef
return $ Assert $ InstrAssert {_instrAssertValue = r}

parseTrace :: (Member LabelInfoBuilder r) => ParsecS r Instruction
parseTrace = do
kw kwTrace
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Casm/Validate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ validate labi instrs = mapM_ go instrs
Call x -> goCall x
Return -> return ()
Alloc x -> goAlloc x
Assert x -> goAssert x
Trace x -> goTrace x
Hint {} -> return ()
Label {} -> return ()
Expand Down Expand Up @@ -66,3 +67,6 @@ validate labi instrs = mapM_ go instrs

goTrace :: InstrTrace -> Either CasmError ()
goTrace InstrTrace {..} = goRValue _instrTraceValue

goAssert :: InstrAssert -> Either CasmError ()
goAssert InstrAssert {} = return ()
8 changes: 7 additions & 1 deletion src/Juvix/Compiler/Concrete/Data/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ instance Serialize BuiltinConstructor
instance NFData BuiltinConstructor

data BuiltinFunction
= BuiltinNatPlus
= BuiltinAssert
| BuiltinNatPlus
| BuiltinNatSub
| BuiltinNatMul
| BuiltinNatUDiv
Expand Down Expand Up @@ -163,6 +164,7 @@ instance NFData BuiltinFunction

instance Pretty BuiltinFunction where
pretty = \case
BuiltinAssert -> Str.assert
BuiltinNatPlus -> Str.natPlus
BuiltinNatSub -> Str.natSub
BuiltinNatMul -> Str.natMul
Expand Down Expand Up @@ -368,6 +370,7 @@ isNatBuiltin = \case
BuiltinNatLt -> True
BuiltinNatEq -> True
--
BuiltinAssert -> False
BuiltinBoolIf -> False
BuiltinBoolOr -> False
BuiltinBoolAnd -> False
Expand Down Expand Up @@ -403,6 +406,7 @@ isIntBuiltin = \case
BuiltinIntLe -> True
BuiltinIntLt -> True
--
BuiltinAssert -> False
BuiltinNatPlus -> False
BuiltinNatSub -> False
BuiltinNatMul -> False
Expand All @@ -425,6 +429,7 @@ isCastBuiltin = \case
BuiltinFromNat -> True
BuiltinFromInt -> True
--
BuiltinAssert -> False
BuiltinIntEq -> False
BuiltinIntPlus -> False
BuiltinIntSubNat -> False
Expand Down Expand Up @@ -496,6 +501,7 @@ isIgnoredBuiltin f
-- Monad
BuiltinMonadBind -> False
-- Ignored
BuiltinAssert -> True
BuiltinBoolIf -> True
BuiltinBoolOr -> True
BuiltinBoolAnd -> True
Expand Down
Loading

0 comments on commit 26ea94b

Please sign in to comment.