Skip to content

Commit

Permalink
Add support for unsigned 8-bit integer type Byte (#2918)
Browse files Browse the repository at this point in the history
This PR adds `Byte` as a builtin with builtin functions for equality,
`byte-from-nat` and `byte-to-nat`. The standard library is updated to
include this definition with instances for `FromNatural`, `Show` and
`Eq` traits.

The `FromNatural` trait means that you can assign `Byte` values using
non-negative numeric literals.


You can use byte literals in jvc files by adding the u8 suffix to a
numeric value. For example, 1u8 represents a byte literal.

Arithmetic is not supported as the intention is for this type to be used
to construct ByteArrays of data where isn't not appropriate to modify
using arithmetic operations. We may add a separate `UInt8` type in the
future which supports arithmetic.

The Byte is supported in the native, rust and Anoma backend. Byte is not
supported in the Cairo backend because `byte-from-nat` cannot be
defined.

The primitive builtin ops for `Byte` are called `OpUInt8ToInt` and
`OpUInt8FromInt`, named because these ops work on integers and in future
we may reuse these for a separate unsigned 8-bit integer type that
supports arithmetic.

Part of:

* #2865
  • Loading branch information
paulcadman authored Aug 2, 2024
1 parent d3f57a6 commit e2fe830
Show file tree
Hide file tree
Showing 49 changed files with 405 additions and 7 deletions.
2 changes: 1 addition & 1 deletion juvix-stdlib
4 changes: 4 additions & 0 deletions runtime/c/src/juvix/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@

#define JUVIX_ASSIGN(var0, val) (var0 = val)

#define JUVIX_UINT8_TO_INT(var0, var1) (var0 = var1)
#define JUVIX_INT_TO_UINT8(var0, var1) \
(var0 = make_smallint((word_t)((uint8_t)(get_unboxed_int(var1) & 0xFF))))

#define JUVIX_TRACE(val) (io_trace(val))
#define JUVIX_DUMP (stacktrace_dump())
#define JUVIX_FAILURE(val) \
Expand Down
17 changes: 17 additions & 0 deletions runtime/rust/juvix/src/integer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ pub fn smallint_le(x: Word, y: Word) -> Word {
bool_to_word(smallint_value(x) <= smallint_value(y))
}


pub fn uint8_to_int(x : Word) -> Word {
x
}

pub fn int_to_uint8(x : Word) -> Word {
make_smallint(smallint_value(x).rem_euclid(256))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -65,4 +74,12 @@ mod tests {
assert_eq!(make_smallint(x), y);
}
}

#[test]
fn test_int_to_uint8() {
assert_eq!(smallint_value(int_to_uint8(make_smallint(-1))), 255);
assert_eq!(smallint_value(int_to_uint8(make_smallint(255))), 255);
assert_eq!(smallint_value(int_to_uint8(make_smallint(-256))), 0);
assert_eq!(smallint_value(int_to_uint8(make_smallint(256))), 0);
}
}
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Asm/Extra/Memory.hs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ getConstantType = \case
ConstString {} -> TyString
ConstUnit -> TyUnit
ConstVoid -> TyVoid
ConstUInt8 {} -> mkTypeUInt8

getValueType' :: (Member (Error AsmError) r) => Maybe Location -> InfoTable -> Memory -> Value -> Sem r Type
getValueType' loc tab mem = \case
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Asm/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ recurse' sig = go True
checkUnop mkTypeInteger TyField
OpFieldToInt ->
checkUnop TyField mkTypeInteger
OpUInt8ToInt ->
checkUnop mkTypeUInt8 mkTypeInteger
OpIntToUInt8 ->
checkUnop mkTypeInteger mkTypeUInt8
where
checkUnop :: Type -> Type -> Sem r Memory
checkUnop ty1 ty2 = do
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Backend/C/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ fromRegInstr bNoStack info = \case
Reg.OpArgsNum -> "JUVIX_ARGS_NUM"
Reg.OpFieldToInt -> unsupported "field type"
Reg.OpIntToField -> unsupported "field type"
Reg.OpUInt8ToInt -> "JUVIX_UINT8_TO_INT"
Reg.OpIntToUInt8 -> "JUVIX_INT_TO_UINT8"

fromVarRef :: Reg.VarRef -> Expression
fromVarRef Reg.VarRef {..} =
Expand Down Expand Up @@ -347,6 +349,7 @@ fromRegInstr bNoStack info = \case
Reg.ConstString x -> macroCall "GET_CONST_CSTRING" [integer (getStringId info x)]
Reg.ConstUnit -> macroVar "OBJ_UNIT"
Reg.ConstVoid -> macroVar "OBJ_VOID"
Reg.ConstUInt8 x -> macroCall "make_smallint" [integer x]

fromPrealloc :: Reg.InstrPrealloc -> Statement
fromPrealloc Reg.InstrPrealloc {..} =
Expand Down
9 changes: 9 additions & 0 deletions src/Juvix/Compiler/Backend/Rust/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,14 @@ fromRegInstr info = \case
(mkCall "mem.get_closure_largs" [fromValue _instrUnopArg])
Reg.OpFieldToInt -> unsupported "field type"
Reg.OpIntToField -> unsupported "field type"
Reg.OpUInt8ToInt ->
stmtAssign
(mkVarRef _instrUnopResult)
(mkCall "uint8_to_int" [fromValue _instrUnopArg])
Reg.OpIntToUInt8 ->
stmtAssign
(mkVarRef _instrUnopResult)
(mkCall "int_to_uint8" [fromValue _instrUnopArg])

mkVarRef :: Reg.VarRef -> Text
mkVarRef Reg.VarRef {..} = case _varRefGroup of
Expand Down Expand Up @@ -242,6 +250,7 @@ fromRegInstr info = \case
Reg.ConstString {} -> unsupported "strings"
Reg.ConstUnit -> mkVar "OBJ_UNIT"
Reg.ConstVoid -> mkVar "OBJ_VOID"
Reg.ConstUInt8 x -> mkCall "make_smallint" [mkInteger x]

fromAlloc :: Reg.InstrAlloc -> [Statement]
fromAlloc Reg.InstrAlloc {..} =
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ module Juvix.Compiler.Builtins
module Juvix.Compiler.Builtins.Control,
module Juvix.Compiler.Builtins.Anoma,
module Juvix.Compiler.Builtins.Cairo,
module Juvix.Compiler.Builtins.Byte,
)
where

import Juvix.Compiler.Builtins.Anoma
import Juvix.Compiler.Builtins.Bool
import Juvix.Compiler.Builtins.Byte
import Juvix.Compiler.Builtins.Cairo
import Juvix.Compiler.Builtins.Control
import Juvix.Compiler.Builtins.Debug
Expand Down
32 changes: 32 additions & 0 deletions src/Juvix/Compiler/Builtins/Byte.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
module Juvix.Compiler.Builtins.Byte where

import Juvix.Compiler.Builtins.Effect
import Juvix.Compiler.Internal.Extra
import Juvix.Prelude

registerByte :: (Member Builtins r) => AxiomDef -> Sem r ()
registerByte d = do
unless (isSmallUniverse' (d ^. axiomType)) (error "Byte should be in the small universe")
registerBuiltin BuiltinByte (d ^. axiomName)

registerByteEq :: (Member Builtins r) => AxiomDef -> Sem r ()
registerByteEq f = do
byte_ <- getBuiltinName (getLoc f) BuiltinByte
bool_ <- getBuiltinName (getLoc f) BuiltinBool
unless (f ^. axiomType === (byte_ --> byte_ --> bool_)) (error "Byte equality has the wrong type signature")
registerBuiltin BuiltinByteEq (f ^. axiomName)

registerByteFromNat :: (Member Builtins r) => AxiomDef -> Sem r ()
registerByteFromNat d = do
let l = getLoc d
byte_ <- getBuiltinName l BuiltinByte
nat <- getBuiltinName l BuiltinNat
unless (d ^. axiomType === (nat --> byte_)) (error "byte-from-nat has the wrong type signature")
registerBuiltin BuiltinByteFromNat (d ^. axiomName)

registerByteToNat :: (Member Builtins r) => AxiomDef -> Sem r ()
registerByteToNat f = do
byte_ <- getBuiltinName (getLoc f) BuiltinByte
nat_ <- getBuiltinName (getLoc f) BuiltinNat
unless (f ^. axiomType === (byte_ --> nat_)) (error "byte-to-nat has the wrong type signature")
registerBuiltin BuiltinByteToNat (f ^. axiomName)
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Casm/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
Reg.ConstUnit -> 0
Reg.ConstVoid -> 0
Reg.ConstString {} -> unsupported "strings"
Reg.ConstUInt8 {} -> unsupported "uint8"

mkLoad :: Reg.ConstrField -> Sem r RValue
mkLoad Reg.ConstrField {..} = do
Expand Down Expand Up @@ -458,6 +459,8 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
Reg.OpFieldToInt -> goAssignValue _instrUnopResult _instrUnopArg
Reg.OpIntToField -> goAssignValue _instrUnopResult _instrUnopArg
Reg.OpArgsNum -> goUnop' goOpArgsNum _instrUnopResult _instrUnopArg
Reg.OpUInt8ToInt -> unsupported "OpUInt8ToInt"
Reg.OpIntToUInt8 -> unsupported "OpIntToUInt8"

goCairo :: Reg.InstrCairo -> Sem r ()
goCairo Reg.InstrCairo {..} = case _instrCairoOpcode of
Expand Down
12 changes: 12 additions & 0 deletions src/Juvix/Compiler/Concrete/Data/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ data BuiltinAxiom
| BuiltinPoseidon
| BuiltinEcOp
| BuiltinRandomEcPoint
| BuiltinByte
| BuiltinByteEq
| BuiltinByteToNat
| BuiltinByteFromNat
deriving stock (Show, Eq, Ord, Enum, Bounded, Generic, Data)

instance HasNameKind BuiltinAxiom where
Expand Down Expand Up @@ -255,6 +259,10 @@ instance HasNameKind BuiltinAxiom where
BuiltinPoseidon -> KNameFunction
BuiltinEcOp -> KNameFunction
BuiltinRandomEcPoint -> KNameFunction
BuiltinByte -> KNameInductive
BuiltinByteEq -> KNameFunction
BuiltinByteToNat -> KNameFunction
BuiltinByteFromNat -> KNameFunction

getNameKindPretty :: BuiltinAxiom -> NameKind
getNameKindPretty = getNameKind
Expand Down Expand Up @@ -300,6 +308,10 @@ instance Pretty BuiltinAxiom where
BuiltinPoseidon -> Str.cairoPoseidon
BuiltinEcOp -> Str.cairoEcOp
BuiltinRandomEcPoint -> Str.cairoRandomEcPoint
BuiltinByte -> Str.byte_
BuiltinByteEq -> Str.byteEq
BuiltinByteToNat -> Str.byteToNat
BuiltinByteFromNat -> Str.byteFromNat

data BuiltinType
= BuiltinTypeInductive BuiltinInductive
Expand Down
34 changes: 34 additions & 0 deletions src/Juvix/Compiler/Core/Evaluator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ geval opts herr tab env0 = eval' env0
OpPoseidonHash -> poseidonHashOp
OpEc -> ecOp
OpRandomEcPoint -> randomEcPointOp
OpUInt8ToInt -> uint8ToIntOp
OpUInt8FromInt -> uint8FromIntOp
where
err :: Text -> a
err msg = evalError msg n
Expand Down Expand Up @@ -509,6 +511,28 @@ geval opts herr tab env0 = eval' env0
!publicKey = publicKeyFromInteger publicKeyInt
in nodeFromBool (E.dverify publicKey message sig)
{-# INLINE verifyDetached #-}

uint8FromIntOp :: [Node] -> Node
uint8FromIntOp =
unary $ \node ->
let !v = eval' env node
in nodeFromUInt8
. fromIntegral
. fromMaybe (evalError "expected integer" v)
. integerFromNode
$ v
{-# INLINE uint8FromIntOp #-}

uint8ToIntOp :: [Node] -> Node
uint8ToIntOp =
unary $ \node ->
let !v = eval' env node
in nodeFromInteger
. toInteger
. fromMaybe (evalError "expected uint8" v)
. uint8FromNode
$ v
{-# INLINE uint8ToIntOp #-}
{-# INLINE applyBuiltin #-}

-- secretKey, publicKey are not encoded with their length as
Expand All @@ -530,6 +554,10 @@ geval opts herr tab env0 = eval' env0
nodeFromField !fld = mkConstant' (ConstField fld)
{-# INLINE nodeFromField #-}

nodeFromUInt8 :: Word8 -> Node
nodeFromUInt8 !w = mkConstant' (ConstUInt8 w)
{-# INLINE nodeFromUInt8 #-}

nodeFromBool :: Bool -> Node
nodeFromBool b = mkConstr' (BuiltinTag tag) []
where
Expand Down Expand Up @@ -577,6 +605,12 @@ geval opts herr tab env0 = eval' env0
_ -> Nothing
{-# INLINE fieldFromNode #-}

uint8FromNode :: Node -> Maybe Word8
uint8FromNode = \case
NCst (Constant _ (ConstUInt8 i)) -> Just i
_ -> Nothing
{-# INLINE uint8FromNode #-}

printNode :: Node -> Text
printNode = \case
NCst (Constant _ (ConstString s)) -> s
Expand Down
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,12 @@ mkTypeField i = mkTypePrim i PrimField
mkTypeField' :: Type
mkTypeField' = mkTypeField Info.empty

mkTypeUInt8 :: Info -> Type
mkTypeUInt8 i = mkTypePrim i primitiveUInt8

mkTypeUInt8' :: Type
mkTypeUInt8' = mkTypeUInt8 Info.empty

mkDynamic :: Info -> Type
mkDynamic i = NDyn (DynamicTy i)

Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,8 @@ builtinOpArgTypes = \case
OpPoseidonHash -> [mkDynamic']
OpEc -> [mkDynamic', mkTypeField', mkDynamic']
OpRandomEcPoint -> []
OpUInt8ToInt -> [mkTypeUInt8']
OpUInt8FromInt -> [mkTypeInteger']

translateCase :: (Node -> Node -> Node -> a) -> a -> Case -> a
translateCase translateIfFun dflt Case {..} = case _caseBranches of
Expand Down
9 changes: 9 additions & 0 deletions src/Juvix/Compiler/Core/Language/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ data BuiltinOp
| OpPoseidonHash
| OpEc
| OpRandomEcPoint
| OpUInt8ToInt
| OpUInt8FromInt
deriving stock (Eq, Generic)

instance Serialize BuiltinOp
Expand Down Expand Up @@ -90,6 +92,8 @@ builtinOpArgsNum = \case
OpPoseidonHash -> 1
OpEc -> 3
OpRandomEcPoint -> 0
OpUInt8ToInt -> 1
OpUInt8FromInt -> 1

builtinConstrArgsNum :: BuiltinDataTag -> Int
builtinConstrArgsNum = \case
Expand Down Expand Up @@ -133,6 +137,8 @@ builtinIsFoldable = \case
OpPoseidonHash -> False
OpEc -> False
OpRandomEcPoint -> False
OpUInt8ToInt -> True
OpUInt8FromInt -> True

builtinIsCairo :: BuiltinOp -> Bool
builtinIsCairo op = op `elem` builtinsCairo
Expand All @@ -156,3 +162,6 @@ builtinsAnoma =
OpAnomaVerifyWithMessage,
OpAnomaSignDetached
]

builtinsUInt8 :: [BuiltinOp]
builtinsUInt8 = [OpUInt8FromInt, OpUInt8ToInt]
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Core/Language/Nodes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ data ConstantValue
= ConstInteger !Integer
| ConstField !FField
| ConstString !Text
| ConstUInt8 !Word8
deriving stock (Eq, Generic)

-- | Info about a single binder. Associated with Lambda, Pi, Let, Case or Match.
Expand Down
9 changes: 9 additions & 0 deletions src/Juvix/Compiler/Core/Language/Primitives.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ data Primitive
| PrimField
deriving stock (Eq, Generic)

primitiveUInt8 :: Primitive
primitiveUInt8 =
PrimInteger
( PrimIntegerInfo
{ _infoMinValue = Just 0,
_infoMaxValue = Just 255
}
)

-- | Info about a type represented as an integer.
data PrimIntegerInfo = PrimIntegerInfo
{ _infoMinValue :: Maybe Integer,
Expand Down
12 changes: 12 additions & 0 deletions src/Juvix/Compiler/Core/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ instance PrettyCode BuiltinOp where
OpPoseidonHash -> return primPoseidonHash
OpEc -> return primEc
OpRandomEcPoint -> return primRandomEcPoint
OpUInt8ToInt -> return primUInt8ToInt
OpUInt8FromInt -> return primFieldFromInt

instance PrettyCode BuiltinDataTag where
ppCode = \case
Expand Down Expand Up @@ -107,13 +109,17 @@ instance PrettyCode ConstantValue where
return $ annotate AnnLiteralInteger (pretty int)
ConstField fld ->
return $ annotate AnnLiteralInteger (pretty fld)
ConstUInt8 i ->
return $ annotate AnnLiteralInteger (pretty i)
ConstString txt ->
return $ annotate AnnLiteralString (pretty (show txt :: String))

instance PrettyCode (Constant' i) where
ppCode Constant {..} = case _constantValue of
ConstField fld ->
return $ annotate AnnLiteralInteger (pretty fld <> "F")
ConstUInt8 i ->
return $ annotate AnnLiteralInteger (pretty i <> "u8")
_ -> ppCode _constantValue

instance (PrettyCode a, HasAtomicity a) => PrettyCode (App' i a) where
Expand Down Expand Up @@ -732,6 +738,12 @@ primFieldDiv = primitive Str.fdiv
primFieldFromInt :: Doc Ann
primFieldFromInt = primitive Str.itof

primUInt8ToInt :: Doc Ann
primUInt8ToInt = primitive Str.u8toi

primUInt8FromInt :: Doc Ann
primUInt8FromInt = primitive Str.itou8

primFieldToInt :: Doc Ann
primFieldToInt = primitive Str.ftoi

Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Transformation/Check/Cairo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ checkCairo md = do
checkMainType
checkNoAxioms md
mapAllNodesM checkNoIO md
mapAllNodesM (checkBuiltins' builtinsString [PrimString]) md
mapAllNodesM (checkBuiltins' (builtinsString ++ builtinsUInt8) [PrimString, primitiveUInt8]) md
where
checkMainType :: Sem r ()
checkMainType =
Expand Down
Loading

0 comments on commit e2fe830

Please sign in to comment.