From fd3959da3807f746333bfec4ea91b699f215cbda Mon Sep 17 00:00:00 2001 From: rian Date: Wed, 15 Jan 2025 11:59:26 +0200 Subject: [PATCH] comments: db_utils.go, inline, felt.Zero --- mempool/db_utils.go | 63 +++++++++++++++++++++++++ mempool/mempool.go | 102 +++++++++------------------------------- mempool/mempool_test.go | 4 +- 3 files changed, 86 insertions(+), 83 deletions(-) create mode 100644 mempool/db_utils.go diff --git a/mempool/db_utils.go b/mempool/db_utils.go new file mode 100644 index 0000000000..1638910464 --- /dev/null +++ b/mempool/db_utils.go @@ -0,0 +1,63 @@ +package mempool + +import ( + "errors" + "math/big" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/encoder" +) + +func headValue(txn db.Transaction, head *felt.Felt) error { + return txn.Get(db.MempoolHead.Key(), func(b []byte) error { + head.SetBytes(b) + return nil + }) +} + +func tailValue(txn db.Transaction, tail *felt.Felt) error { + return txn.Get(db.MempoolTail.Key(), func(b []byte) error { + tail.SetBytes(b) + return nil + }) +} + +func updateHead(txn db.Transaction, head *felt.Felt) error { + return txn.Set(db.MempoolHead.Key(), head.Marshal()) +} + +func updateTail(txn db.Transaction, tail *felt.Felt) error { + return txn.Set(db.MempoolTail.Key(), tail.Marshal()) +} + +func readDBElem(txn db.Transaction, itemKey *felt.Felt) (dbPoolTxn, error) { + var item dbPoolTxn + keyBytes := itemKey.Bytes() + err := txn.Get(db.MempoolNode.Key(keyBytes[:]), func(b []byte) error { + return encoder.Unmarshal(b, &item) + }) + return item, err +} + +func setDBElem(txn db.Transaction, item *dbPoolTxn) error { + itemBytes, err := encoder.Marshal(item) + if err != nil { + return err + } + keyBytes := item.Txn.Transaction.Hash().Bytes() + return txn.Set(db.MempoolNode.Key(keyBytes[:]), itemBytes) +} + +func lenDB(txn db.Transaction) (int, error) { + var l int + err := txn.Get(db.MempoolLength.Key(), func(b []byte) error { + l = int(new(big.Int).SetBytes(b).Int64()) + return nil + }) + + if err != nil && errors.Is(err, db.ErrKeyNotFound) { + return 0, nil + } + return l, err +} diff --git a/mempool/mempool.go b/mempool/mempool.go index 5a615b8270..4ebb80f7c8 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -9,7 +9,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/encoder" "github.com/NethermindEth/juno/utils" ) @@ -123,8 +122,8 @@ func (p *Pool) dbWriter() { // LoadFromDB restores the in-memory transaction pool from the database func (p *Pool) LoadFromDB() error { return p.db.View(func(txn db.Transaction) error { - headValue := new(felt.Felt) - err := p.headHash(txn, headValue) + headVal := new(felt.Felt) + err := headValue(txn, headVal) if err != nil { if errors.Is(err, db.ErrKeyNotFound) { return nil @@ -132,9 +131,9 @@ func (p *Pool) LoadFromDB() error { return err } // loop through the persistent pool and push nodes to the in-memory pool - currentHash := headValue + currentHash := headVal for currentHash != nil { - curDBElem, err := p.readDBElem(txn, currentHash) + curDBElem, err := readDBElem(txn, currentHash) if err != nil { return err } @@ -142,7 +141,7 @@ func (p *Pool) LoadFromDB() error { Txn: curDBElem.Txn, } if curDBElem.NextHash != nil { - nextDBTxn, err := p.readDBElem(txn, curDBElem.NextHash) + nextDBTxn, err := readDBElem(txn, curDBElem.NextHash) if err != nil { return err } @@ -160,41 +159,41 @@ func (p *Pool) LoadFromDB() error { // writeToDB adds the transaction to the persistent pool db func (p *Pool) writeToDB(userTxn *BroadcastedTransaction) error { return p.db.Update(func(dbTxn db.Transaction) error { - tailValue := new(felt.Felt) - if err := p.tailValue(dbTxn, tailValue); err != nil { + tailVal := new(felt.Felt) + if err := tailValue(dbTxn, tailVal); err != nil { if !errors.Is(err, db.ErrKeyNotFound) { return err } - tailValue = nil + tailVal = nil } - if err := p.setDBElem(dbTxn, &dbPoolTxn{Txn: *userTxn}); err != nil { + if err := setDBElem(dbTxn, &dbPoolTxn{Txn: *userTxn}); err != nil { return err } - if tailValue != nil { + if tailVal != nil { // Update old tail to point to the new item var oldTailElem dbPoolTxn - oldTailElem, err := p.readDBElem(dbTxn, tailValue) + oldTailElem, err := readDBElem(dbTxn, tailVal) if err != nil { return err } oldTailElem.NextHash = userTxn.Transaction.Hash() - if err = p.setDBElem(dbTxn, &oldTailElem); err != nil { + if err = setDBElem(dbTxn, &oldTailElem); err != nil { return err } } else { // Empty list, make new item both the head and the tail - if err := p.updateHead(dbTxn, userTxn.Transaction.Hash()); err != nil { + if err := updateHead(dbTxn, userTxn.Transaction.Hash()); err != nil { return err } } - if err := p.updateTail(dbTxn, userTxn.Transaction.Hash()); err != nil { + if err := updateTail(dbTxn, userTxn.Transaction.Hash()); err != nil { return err } - pLen, err := p.lenDB(dbTxn) + pLen, err := lenDB(dbTxn) if err != nil { return err } - return p.updateLen(dbTxn, pLen+1) + return dbTxn.Set(db.MempoolLength.Key(), new(big.Int).SetInt64(int64(pLen+1)).Bytes()) }) } @@ -285,78 +284,19 @@ func (p *Pool) Len() int { return p.memTxnList.len } +func (p *Pool) Wait() <-chan struct{} { + return p.txPushed +} + // Len returns the number of transactions in the persistent pool func (p *Pool) LenDB() (int, error) { - p.wg.Add(1) - defer p.wg.Done() txn, err := p.db.NewTransaction(false) if err != nil { return 0, err } - lenDB, err := p.lenDB(txn) + lenDB, err := lenDB(txn) if err != nil { return 0, err } return lenDB, txn.Discard() } - -func (p *Pool) lenDB(txn db.Transaction) (int, error) { - var l int - err := txn.Get(db.MempoolLength.Key(), func(b []byte) error { - l = int(new(big.Int).SetBytes(b).Int64()) - return nil - }) - - if err != nil && errors.Is(err, db.ErrKeyNotFound) { - return 0, nil - } - return l, err -} - -func (p *Pool) updateLen(txn db.Transaction, l int) error { - return txn.Set(db.MempoolLength.Key(), new(big.Int).SetInt64(int64(l)).Bytes()) -} - -func (p *Pool) Wait() <-chan struct{} { - return p.txPushed -} - -func (p *Pool) headHash(txn db.Transaction, head *felt.Felt) error { - return txn.Get(db.MempoolHead.Key(), func(b []byte) error { - head.SetBytes(b) - return nil - }) -} - -func (p *Pool) updateHead(txn db.Transaction, head *felt.Felt) error { - return txn.Set(db.MempoolHead.Key(), head.Marshal()) -} - -func (p *Pool) tailValue(txn db.Transaction, tail *felt.Felt) error { - return txn.Get(db.MempoolTail.Key(), func(b []byte) error { - tail.SetBytes(b) - return nil - }) -} - -func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error { - return txn.Set(db.MempoolTail.Key(), tail.Marshal()) -} - -func (p *Pool) readDBElem(txn db.Transaction, itemKey *felt.Felt) (dbPoolTxn, error) { - var item dbPoolTxn - keyBytes := itemKey.Bytes() - err := txn.Get(db.MempoolNode.Key(keyBytes[:]), func(b []byte) error { - return encoder.Unmarshal(b, &item) - }) - return item, err -} - -func (p *Pool) setDBElem(txn db.Transaction, item *dbPoolTxn) error { - itemBytes, err := encoder.Marshal(item) - if err != nil { - return err - } - keyBytes := item.Txn.Transaction.Hash().Bytes() - return txn.Set(db.MempoolNode.Key(keyBytes[:]), itemBytes) -} diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 39a71cf644..57ede7cd32 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -57,7 +57,7 @@ func TestMempool(t *testing.T) { // push multiple to empty (1,2,3) for i := uint64(1); i < 4; i++ { //nolint:dupl senderAddress := new(felt.Felt).SetUint64(i) - state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) + state.EXPECT().ContractNonce(senderAddress).Return(&felt.Zero, nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), @@ -79,7 +79,7 @@ func TestMempool(t *testing.T) { // push multiple to non empty (push 4,5. now have 3,4,5) for i := uint64(4); i < 6; i++ { senderAddress := new(felt.Felt).SetUint64(i) - state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) + state.EXPECT().ContractNonce(senderAddress).Return(&felt.Zero, nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i),