@@ -3,40 +3,44 @@ package smt
3
3
import (
4
4
"bytes"
5
5
"context"
6
+ "errors"
6
7
"math/big"
7
8
8
9
"github.com/holiman/uint256"
9
10
libcommon "github.com/ledgerwatch/erigon-lib/common"
11
+ "github.com/ledgerwatch/erigon/core/state"
10
12
"github.com/ledgerwatch/erigon/core/types/accounts"
11
13
"github.com/ledgerwatch/erigon/smt/pkg/utils"
12
14
"github.com/ledgerwatch/erigon/zkevm/log"
13
15
)
14
16
17
+ var _ state.StateReader = (* SMT )(nil )
18
+
15
19
// ReadAccountData reads account data from the SMT
16
20
func (s * SMT ) ReadAccountData (address libcommon.Address ) (* accounts.Account , error ) {
17
- account := accounts.Account {}
18
-
19
21
balance , err := s .GetAccountBalance (address )
20
22
if err != nil {
21
23
return nil , err
22
24
}
23
- account .Balance = * balance
24
25
25
26
nonce , err := s .GetAccountNonce (address )
26
27
if err != nil {
27
28
return nil , err
28
29
}
29
- account .Nonce = nonce .Uint64 ()
30
30
31
31
codeHash , err := s .GetAccountCodeHash (address )
32
32
if err != nil {
33
33
return nil , err
34
34
}
35
- account .CodeHash = codeHash
36
35
37
- account .Root = libcommon.Hash {}
36
+ account := & accounts.Account {
37
+ Balance : * balance ,
38
+ Nonce : nonce .Uint64 (),
39
+ CodeHash : codeHash ,
40
+ Root : libcommon.Hash {},
41
+ }
38
42
39
- return & account , nil
43
+ return account , nil
40
44
}
41
45
42
46
// ReadAccountStorage reads account storage from the SMT (not implemented for SMT)
@@ -60,59 +64,69 @@ func (s *SMT) ReadAccountCode(address libcommon.Address, incarnation uint64, cod
60
64
}
61
65
62
66
// ReadAccountCodeSize reads account code size from the SMT
63
- func (s * SMT ) ReadAccountCodeSize (address libcommon.Address , incarnation uint64 , codeHash libcommon.Hash ) (int , error ) {
67
+ func (s * SMT ) ReadAccountCodeSize (address libcommon.Address , _ uint64 , _ libcommon.Hash ) (int , error ) {
64
68
valueInBytes , err := s .getValue (utils .SC_LENGTH , address , nil )
65
69
if err != nil {
66
70
return 0 , err
67
71
}
68
72
69
73
sizeBig := big .NewInt (0 ).SetBytes (valueInBytes )
70
74
71
- return int (sizeBig .Int64 ()), nil
75
+ if ! sizeBig .IsInt64 () {
76
+ err = errors .New ("code size value is too large to fit into an int" )
77
+ return 0 , err
78
+ }
79
+
80
+ sizeInt64 := sizeBig .Int64 ()
81
+ if sizeInt64 > int64 (^ uint (0 )>> 1 ) {
82
+ err = errors .New ("code size value overflows int" )
83
+ log .Error ("failed to get account code size" , "error" , err )
84
+ return 0 , err
85
+ }
86
+
87
+ return int (sizeInt64 ), nil
72
88
}
73
89
74
90
// ReadAccountIncarnation reads account incarnation from the SMT (not implemented for SMT)
75
- func (s * SMT ) ReadAccountIncarnation (address libcommon.Address ) (uint64 , error ) {
76
- return 0 , nil
91
+ func (s * SMT ) ReadAccountIncarnation (_ libcommon.Address ) (uint64 , error ) {
92
+ return 0 , errors . New ( "ReadAccountIncarnation not implemented for SMT" )
77
93
}
78
94
79
95
// GetAccountBalance returns the balance of an account from the SMT
80
96
func (s * SMT ) GetAccountBalance (address libcommon.Address ) (* uint256.Int , error ) {
81
- balance := uint256 .NewInt (0 )
82
-
83
97
valueInBytes , err := s .getValue (utils .KEY_BALANCE , address , nil )
84
98
if err != nil {
85
- log .Error ("error getting balance" , "error" , err )
99
+ log .Error ("failed to get balance" , "error" , err )
86
100
return nil , err
87
101
}
88
- balance .SetBytes (valueInBytes )
102
+
103
+ balance := uint256 .NewInt (0 ).SetBytes (valueInBytes )
89
104
90
105
return balance , nil
91
106
}
92
107
93
108
// GetAccountNonce returns the nonce of an account from the SMT
94
109
func (s * SMT ) GetAccountNonce (address libcommon.Address ) (* uint256.Int , error ) {
95
- nonce := uint256 .NewInt (0 )
96
-
97
110
valueInBytes , err := s .getValue (utils .KEY_NONCE , address , nil )
98
111
if err != nil {
99
- log .Error ("error getting nonce" , "error" , err )
112
+ log .Error ("failed to get nonce" , "error" , err )
100
113
return nil , err
101
114
}
102
- nonce .SetBytes (valueInBytes )
115
+
116
+ nonce := uint256 .NewInt (0 ).SetBytes (valueInBytes )
103
117
104
118
return nonce , nil
105
119
}
106
120
107
121
// GetAccountCodeHash returns the code hash of an account from the SMT
108
122
func (s * SMT ) GetAccountCodeHash (address libcommon.Address ) (libcommon.Hash , error ) {
109
- codeHash := libcommon.Hash {}
110
-
111
123
valueInBytes , err := s .getValue (utils .SC_CODE , address , nil )
112
124
if err != nil {
113
- log .Error ("error getting codehash " , "error" , err )
125
+ log .Error ("failed to get code hash " , "error" , err )
114
126
return libcommon.Hash {}, err
115
127
}
128
+
129
+ codeHash := libcommon.Hash {}
116
130
codeHash .SetBytes (valueInBytes )
117
131
118
132
return codeHash , nil
0 commit comments