Skip to content

Commit c2d0368

Browse files
feat: simplifications in SMT state reader
1 parent 15834d8 commit c2d0368

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

smt/pkg/smt/smt_state_reader.go

+36-22
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,44 @@ package smt
33
import (
44
"bytes"
55
"context"
6+
"errors"
67
"math/big"
78

89
"github.com/holiman/uint256"
910
libcommon "github.com/ledgerwatch/erigon-lib/common"
11+
"github.com/ledgerwatch/erigon/core/state"
1012
"github.com/ledgerwatch/erigon/core/types/accounts"
1113
"github.com/ledgerwatch/erigon/smt/pkg/utils"
1214
"github.com/ledgerwatch/erigon/zkevm/log"
1315
)
1416

17+
var _ state.StateReader = (*SMT)(nil)
18+
1519
// ReadAccountData reads account data from the SMT
1620
func (s *SMT) ReadAccountData(address libcommon.Address) (*accounts.Account, error) {
17-
account := accounts.Account{}
18-
1921
balance, err := s.GetAccountBalance(address)
2022
if err != nil {
2123
return nil, err
2224
}
23-
account.Balance = *balance
2425

2526
nonce, err := s.GetAccountNonce(address)
2627
if err != nil {
2728
return nil, err
2829
}
29-
account.Nonce = nonce.Uint64()
3030

3131
codeHash, err := s.GetAccountCodeHash(address)
3232
if err != nil {
3333
return nil, err
3434
}
35-
account.CodeHash = codeHash
3635

37-
account.Root = libcommon.Hash{}
36+
account := &accounts.Account{
37+
Balance: *balance,
38+
Nonce: nonce.Uint64(),
39+
CodeHash: codeHash,
40+
Root: libcommon.Hash{},
41+
}
3842

39-
return &account, nil
43+
return account, nil
4044
}
4145

4246
// ReadAccountStorage reads account storage from the SMT (not implemented for SMT)
@@ -60,59 +64,69 @@ func (s *SMT) ReadAccountCode(address libcommon.Address, incarnation uint64, cod
6064
}
6165

6266
// ReadAccountCodeSize reads account code size from the SMT
63-
func (s *SMT) ReadAccountCodeSize(address libcommon.Address, incarnation uint64, codeHash libcommon.Hash) (int, error) {
67+
func (s *SMT) ReadAccountCodeSize(address libcommon.Address, _ uint64, _ libcommon.Hash) (int, error) {
6468
valueInBytes, err := s.getValue(utils.SC_LENGTH, address, nil)
6569
if err != nil {
6670
return 0, err
6771
}
6872

6973
sizeBig := big.NewInt(0).SetBytes(valueInBytes)
7074

71-
return int(sizeBig.Int64()), nil
75+
if !sizeBig.IsInt64() {
76+
err = errors.New("code size value is too large to fit into an int")
77+
return 0, err
78+
}
79+
80+
sizeInt64 := sizeBig.Int64()
81+
if sizeInt64 > int64(^uint(0)>>1) {
82+
err = errors.New("code size value overflows int")
83+
log.Error("failed to get account code size", "error", err)
84+
return 0, err
85+
}
86+
87+
return int(sizeInt64), nil
7288
}
7389

7490
// ReadAccountIncarnation reads account incarnation from the SMT (not implemented for SMT)
75-
func (s *SMT) ReadAccountIncarnation(address libcommon.Address) (uint64, error) {
76-
return 0, nil
91+
func (s *SMT) ReadAccountIncarnation(_ libcommon.Address) (uint64, error) {
92+
return 0, errors.New("ReadAccountIncarnation not implemented for SMT")
7793
}
7894

7995
// GetAccountBalance returns the balance of an account from the SMT
8096
func (s *SMT) GetAccountBalance(address libcommon.Address) (*uint256.Int, error) {
81-
balance := uint256.NewInt(0)
82-
8397
valueInBytes, err := s.getValue(utils.KEY_BALANCE, address, nil)
8498
if err != nil {
85-
log.Error("error getting balance", "error", err)
99+
log.Error("failed to get balance", "error", err)
86100
return nil, err
87101
}
88-
balance.SetBytes(valueInBytes)
102+
103+
balance := uint256.NewInt(0).SetBytes(valueInBytes)
89104

90105
return balance, nil
91106
}
92107

93108
// GetAccountNonce returns the nonce of an account from the SMT
94109
func (s *SMT) GetAccountNonce(address libcommon.Address) (*uint256.Int, error) {
95-
nonce := uint256.NewInt(0)
96-
97110
valueInBytes, err := s.getValue(utils.KEY_NONCE, address, nil)
98111
if err != nil {
99-
log.Error("error getting nonce", "error", err)
112+
log.Error("failed to get nonce", "error", err)
100113
return nil, err
101114
}
102-
nonce.SetBytes(valueInBytes)
115+
116+
nonce := uint256.NewInt(0).SetBytes(valueInBytes)
103117

104118
return nonce, nil
105119
}
106120

107121
// GetAccountCodeHash returns the code hash of an account from the SMT
108122
func (s *SMT) GetAccountCodeHash(address libcommon.Address) (libcommon.Hash, error) {
109-
codeHash := libcommon.Hash{}
110-
111123
valueInBytes, err := s.getValue(utils.SC_CODE, address, nil)
112124
if err != nil {
113-
log.Error("error getting codehash", "error", err)
125+
log.Error("failed to get code hash", "error", err)
114126
return libcommon.Hash{}, err
115127
}
128+
129+
codeHash := libcommon.Hash{}
116130
codeHash.SetBytes(valueInBytes)
117131

118132
return codeHash, nil

0 commit comments

Comments
 (0)