Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

The assert builtin #3014

Merged
merged 7 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions examples/milestone/Bank/Bank.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ calculateInterest : Nat -> Nat -> Nat -> Nat
incrAmount (a : Nat) : Nat := div (a * rate) 10000;
in iterate (min 100 periods) incrAmount amount;

--- Asserts some ;Bool; condition.
assert : {A : Type} -> Bool -> A -> A
| c a := ite c a (failwith "assertion failed");

--- Returns a new ;Token;. Arguments are:
---
--- `owner`: The address of the account to issue the token to
Expand All @@ -82,7 +78,7 @@ assert : {A : Type} -> Bool -> A -> A
---
--- `caller`: Who is creating the transaction. It must be the bank.
issue : Address -> Address -> Nat -> Token
| caller owner amount := assert (caller == bankAddress) (mkToken owner 0 amount);
| caller owner amount := assert (caller == bankAddress) >-> mkToken owner 0 amount;

--- Deposits some amount of money into the bank.
deposit (bal : Balances) (token : Token) (amount : Nat) : Token :=
Expand All @@ -102,11 +98,10 @@ withdraw
(rate : Nat)
(periods : Nat)
: Token :=
assert
(caller == bankAddress)
(let
hash : Field := hashAddress recipient;
total : Nat := calculateInterest amount rate periods;
token : Token := mkToken recipient 0 total;
bal' : Balances := decrement hash amount bal;
in runOnChain (commitBalances bal') token);
assert (caller == bankAddress)
>-> let
hash : Field := hashAddress recipient;
total : Nat := calculateInterest amount rate periods;
token : Token := mkToken recipient 0 total;
bal' : Balances := decrement hash amount bal;
in runOnChain (commitBalances bal') token;
2 changes: 1 addition & 1 deletion juvix-stdlib
1 change: 1 addition & 0 deletions runtime/c/src/juvix/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
#define JUVIX_INT_TO_UINT8(var0, var1) \
(var0 = make_smallint((word_t)((uint8_t)(get_unboxed_int(var1) & 0xFF))))

#define JUVIX_ASSERT(val) (assert(is_true(val)))
#define JUVIX_TRACE(val) (io_trace(val))
#define JUVIX_DUMP (stacktrace_dump())
#define JUVIX_FAILURE(val) \
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Asm/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ recurse' sig = go True
throw $
AsmError loc "popping empty value stack"
return (popValueStack 1 mem)
Assert ->
return mem
Trace ->
return mem
Dump ->
Expand Down Expand Up @@ -412,6 +414,8 @@ recurseS' sig = go True
return (stackInfoPushValueStack 1 si)
Pop -> do
return (stackInfoPopValueStack 1 si)
Assert ->
return si
Trace ->
return si
Dump ->
Expand Down
5 changes: 5 additions & 0 deletions src/Juvix/Compiler/Asm/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ runCodeR infoTable funInfo = goCode (funInfo ^. functionCode) >> popLastValueSta
goCode cont
Pop ->
popValueStack >> goCode cont
Assert -> do
v <- topValueStack
unless (v == ValBool True) $
runtimeError "assertion failed"
goCode cont
Trace -> do
v <- topValueStack
logMessage (printValue infoTable v)
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Asm/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ data Instruction
Push Value
| -- | Pop the stack. JVA opcode: 'pop'.
Pop
| -- | Assert a boolean on top of the stack. Does not pop the stack. JVA
-- opcode: 'assert'.
Assert
| -- | Print a debug log of the object on top of the stack. Does not pop the
-- stack. JVA opcode: 'trace'.
Trace
Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Asm/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ instance PrettyCode Instruction where
Cairo op -> Tree.ppCode op
Push val -> (primitive Str.instrPush <+>) <$> ppCode val
Pop -> return $ primitive Str.instrPop
Assert -> return $ primitive Str.instrAssert
Trace -> return $ primitive Str.instrTrace
Dump -> return $ primitive Str.instrDump
Failure -> return $ primitive Str.instrFailure
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Asm/Translation/FromSource.hs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ command = do
mkInstr' loc . Push <$> value
"pop" ->
return $ mkInstr' loc Pop
"assert" ->
return $ mkInstr' loc Assert
"trace" ->
return $ mkInstr' loc Trace
"dump" ->
Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Asm/Translation/FromTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ genCode fi =
genUnOp :: Tree.UnaryOpcode -> Command
genUnOp op = case op of
Tree.PrimUnop op' -> mkUnop op'
Tree.OpAssert -> mkInstr Assert
Tree.OpTrace -> mkInstr Trace
Tree.OpFail -> mkInstr Failure

Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Backend/C/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ fromRegInstr bNoStack info = \case
unsupported "Cairo builtin"
Reg.Assign Reg.InstrAssign {..} ->
return $ stmtsAssign (fromVarRef _instrAssignResult) (fromValue _instrAssignValue)
Reg.Assert Reg.InstrAssert {..} ->
return [StatementExpr $ macroCall "JUVIX_ASSERT" [fromValue _instrAssertValue]]
Reg.Trace Reg.InstrTrace {..} ->
return [StatementExpr $ macroCall "JUVIX_TRACE" [fromValue _instrTraceValue]]
Reg.Dump ->
Expand Down
9 changes: 9 additions & 0 deletions src/Juvix/Compiler/Backend/Cairo/Translation/FromCasm.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ fromCasm instrs0 =
Casm.Return -> goReturn
Casm.Alloc x -> goAlloc x
Casm.Hint x -> goHint x
Casm.Assert x -> goAssert x
Casm.Trace {} -> []
Casm.Label {} -> []
Casm.Nop -> []
Expand Down Expand Up @@ -230,6 +231,14 @@ fromCasm instrs0 =
. set instrApUpdate ApUpdateAdd
$ defaultInstruction

goAssert :: Casm.InstrAssert -> [Element]
goAssert Casm.InstrAssert {..} =
toElems
. updateOps False (Casm.Val (Casm.Imm 0))
. updateDst _instrAssertValue
. set instrOpcode AssertEq
$ defaultInstruction

goHint :: Casm.Hint -> [Element]
goHint = \case
Casm.HintInput var -> [ElementHint (Hint ("Input(" <> var <> ")"))]
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Backend/Rust/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ fromRegInstr info = \case
unsupported "Cairo builtin"
Reg.Assign Reg.InstrAssign {..} ->
stmtsAssign (mkVarRef _instrAssignResult) (fromValue _instrAssignValue)
Reg.Assert {} ->
unsupported "assert"
Reg.Trace {} ->
unsupported "trace"
Reg.Dump ->
Expand Down
24 changes: 24 additions & 0 deletions src/Juvix/Compiler/Builtins/Assert.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module Juvix.Compiler.Builtins.Assert where

import Juvix.Compiler.Internal.Builtins
import Juvix.Compiler.Internal.Extra
import Juvix.Prelude

checkAssert :: (Members '[Reader BuiltinsTable, Error ScoperError, NameIdGen] r) => FunctionDef -> Sem r ()
checkAssert f = do
bool_ <- getBuiltinNameScoper (getLoc f) BuiltinBool
let assert_ = f ^. funDefName
l = getLoc f
varx <- freshVar l "x"
let x = toExpression varx
assertClauses :: [(Expression, Expression)]
assertClauses = [(assert_ @@ x, x)]
checkBuiltinFunctionInfo
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinAssert,
_funInfoSignature = bool_ --> bool_,
_funInfoClauses = assertClauses,
_funInfoFreeVars = [varx],
_funInfoFreeTypeVars = []
}
8 changes: 8 additions & 0 deletions src/Juvix/Compiler/Casm/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ hRunCode hout inputInfo (LabelInfo labelInfo) instrs0 = runST goCode
Call x -> goCall x pc ap fp mem
Return -> goReturn pc ap fp mem
Alloc x -> goAlloc x pc ap fp mem
Assert x -> goAssert x pc ap fp mem
Trace x -> goTrace x pc ap fp mem
Hint x -> goHint x pc ap fp mem
Label {} -> go (pc + 1) ap fp mem
Expand Down Expand Up @@ -244,6 +245,13 @@ hRunCode hout inputInfo (LabelInfo labelInfo) instrs0 = runST goCode
v <- readRValue ap fp mem _instrAllocSize
go (pc + 1) (ap + fromInteger (fieldToInteger v)) fp mem

goAssert :: InstrAssert -> Address -> Address -> Address -> Memory s -> ST s FField
goAssert InstrAssert {..} pc ap fp mem = do
v <- readMemRef ap fp mem _instrAssertValue
when (fieldToInteger v /= 0) $
throwRunError "assertion failed"
go (pc + 1) ap fp mem

goTrace :: InstrTrace -> Address -> Address -> Address -> Memory s -> ST s FField
goTrace InstrTrace {..} pc ap fp mem = do
v <- readRValue ap fp mem _instrTraceValue
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Casm/Keywords.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Juvix.Data.Keyword.All
kwAbs,
kwAp,
kwApPlusPlus,
kwAssert,
kwCall,
kwColon,
kwDiv,
Expand Down Expand Up @@ -45,6 +46,7 @@ allKeywords =
kwAbs,
kwAp,
kwApPlusPlus,
kwAssert,
kwCall,
kwColon,
kwDiv,
Expand Down
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Casm/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ data Instruction
| Call InstrCall
| Return
| Alloc InstrAlloc
| Assert InstrAssert
| Trace InstrTrace
| Hint Hint
| Label LabelRef
Expand Down Expand Up @@ -132,6 +133,10 @@ newtype InstrAlloc = InstrAlloc
{ _instrAllocSize :: RValue
}

newtype InstrAssert = InstrAssert
{ _instrAssertValue :: MemRef
}

newtype InstrTrace = InstrTrace
{ _instrTraceValue :: RValue
}
Expand All @@ -148,4 +153,5 @@ makeLenses ''InstrJump
makeLenses ''InstrJumpIf
makeLenses ''InstrCall
makeLenses ''InstrAlloc
makeLenses ''InstrAssert
makeLenses ''InstrTrace
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Casm/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ instance PrettyCode InstrAlloc where
s <- ppCode _instrAllocSize
return $ Str.ap <+> Str.plusequal <+> s

instance PrettyCode InstrAssert where
ppCode InstrAssert {..} = do
v <- ppCode _instrAssertValue
return $ Str.assert_ <+> v

instance PrettyCode InstrTrace where
ppCode InstrTrace {..} = do
v <- ppCode _instrTraceValue
Expand All @@ -185,6 +190,7 @@ instance PrettyCode Instruction where
Call x -> ppCode x
Return -> return Str.ret
Alloc x -> ppCode x
Assert x -> ppCode x
Trace x -> ppCode x
Hint x -> ppCode x
Label x -> (<> colon) <$> ppCode x
Expand Down
13 changes: 13 additions & 0 deletions src/Juvix/Compiler/Casm/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
Reg.Assign x -> goAssign x
Reg.Alloc x -> goAlloc x
Reg.AllocClosure x -> goAllocClosure x
Reg.Assert x -> goAssert x
Reg.Trace x -> goTrace x
Reg.Dump -> unsupported "dump"
Reg.Failure x -> goFail x
Expand Down Expand Up @@ -512,6 +513,18 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
storedArgsNum = length _instrAllocClosureArgs
leftArgsNum = _instrAllocClosureExpectedArgsNum - storedArgsNum

goAssert :: Reg.InstrAssert -> Sem r ()
goAssert Reg.InstrAssert {..} = do
v <- goValue _instrAssertValue
case v of
Imm c
| c == 0 -> return ()
| otherwise ->
output' 0 $ mkAssign (MemRef Ap 0) (Binop $ BinopValue FieldAdd (MemRef Ap 0) (Imm 1))
Ref r ->
output' 0 $ Assert (InstrAssert r)
Lab {} -> unsupported "assert label"

goTrace :: Reg.InstrTrace -> Sem r ()
goTrace Reg.InstrTrace {..} = do
v <- mkRValue _instrTraceValue
Expand Down
7 changes: 7 additions & 0 deletions src/Juvix/Compiler/Casm/Translation/FromSource.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ instruction =
<|> parseJump
<|> parseCall
<|> parseReturn
<|> parseAssert
<|> parseTrace
<|> parseAssign

Expand Down Expand Up @@ -249,6 +250,12 @@ parseReturn = do
kw kwRet
return Return

parseAssert :: ParsecS r Instruction
parseAssert = do
kw kwAssert
r <- parseMemRef
return $ Assert $ InstrAssert {_instrAssertValue = r}

parseTrace :: (Member LabelInfoBuilder r) => ParsecS r Instruction
parseTrace = do
kw kwTrace
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Casm/Validate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ validate labi instrs = mapM_ go instrs
Call x -> goCall x
Return -> return ()
Alloc x -> goAlloc x
Assert x -> goAssert x
Trace x -> goTrace x
Hint {} -> return ()
Label {} -> return ()
Expand Down Expand Up @@ -66,3 +67,6 @@ validate labi instrs = mapM_ go instrs

goTrace :: InstrTrace -> Either CasmError ()
goTrace InstrTrace {..} = goRValue _instrTraceValue

goAssert :: InstrAssert -> Either CasmError ()
goAssert InstrAssert {} = return ()
8 changes: 7 additions & 1 deletion src/Juvix/Compiler/Concrete/Data/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ instance Serialize BuiltinConstructor
instance NFData BuiltinConstructor

data BuiltinFunction
= BuiltinNatPlus
= BuiltinAssert
| BuiltinNatPlus
| BuiltinNatSub
| BuiltinNatMul
| BuiltinNatUDiv
Expand Down Expand Up @@ -163,6 +164,7 @@ instance NFData BuiltinFunction

instance Pretty BuiltinFunction where
pretty = \case
BuiltinAssert -> Str.assert
BuiltinNatPlus -> Str.natPlus
BuiltinNatSub -> Str.natSub
BuiltinNatMul -> Str.natMul
Expand Down Expand Up @@ -368,6 +370,7 @@ isNatBuiltin = \case
BuiltinNatLt -> True
BuiltinNatEq -> True
--
BuiltinAssert -> False
BuiltinBoolIf -> False
BuiltinBoolOr -> False
BuiltinBoolAnd -> False
Expand Down Expand Up @@ -403,6 +406,7 @@ isIntBuiltin = \case
BuiltinIntLe -> True
BuiltinIntLt -> True
--
BuiltinAssert -> False
BuiltinNatPlus -> False
BuiltinNatSub -> False
BuiltinNatMul -> False
Expand All @@ -425,6 +429,7 @@ isCastBuiltin = \case
BuiltinFromNat -> True
BuiltinFromInt -> True
--
BuiltinAssert -> False
BuiltinIntEq -> False
BuiltinIntPlus -> False
BuiltinIntSubNat -> False
Expand Down Expand Up @@ -496,6 +501,7 @@ isIgnoredBuiltin f
-- Monad
BuiltinMonadBind -> False
-- Ignored
BuiltinAssert -> True
BuiltinBoolIf -> True
BuiltinBoolOr -> True
BuiltinBoolAnd -> True
Expand Down
Loading
Loading