diff --git a/jsonrpc/endpoints_zkevm.go b/jsonrpc/endpoints_zkevm.go index f159885cb3..800ea5c947 100644 --- a/jsonrpc/endpoints_zkevm.go +++ b/jsonrpc/endpoints_zkevm.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "math/big" + "slices" "time" "github.com/0xPolygonHermez/zkevm-node/hex" @@ -638,3 +639,69 @@ func (z *ZKEVMEndpoints) GetLatestGlobalExitRoot() (interface{}, types.Error) { return ger.String(), nil } + +// GetForkId returns the network's current fork ID +func (z *ZKEVMEndpoints) GetForkId() (interface{}, types.Error) { + ctx := context.Background() + forkID, err := z.state.GetCurrentForkID(ctx, nil) + if err != nil { + return "0x0", types.NewRPCError(types.DefaultErrorCode, "failed to get the current fork id from state") + } + + return hex.EncodeUint64(forkID), nil +} + +// GetForkById returns the network fork ID interval given the provided fork id +func (z *ZKEVMEndpoints) GetForkById(forkID types.ArgUint64) (interface{}, types.Error) { + ctx := context.Background() + forkIDInterval, err := z.state.GetForkByID(ctx, uint64(forkID), nil) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } else if err != nil { + return nil, types.NewRPCError(types.DefaultErrorCode, "failed to get the fork interval by id from state") + } + + res := types.NewForkIDInterval(*forkIDInterval) + return res, nil +} + +// GetForkIdByBatchNumber returns the fork ID given the provided batch number +func (z *ZKEVMEndpoints) GetForkIdByBatchNumber(batchNumber types.BatchNumber) (interface{}, types.Error) { + ctx := context.Background() + + numericBatchNumber, rpcErr := batchNumber.GetNumericBatchNumber(ctx, z.state, z.etherman, nil) + if rpcErr != nil { + return nil, rpcErr + } + + forkID := z.state.GetForkIDByBatchNumber(numericBatchNumber) + return hex.EncodeUint64(forkID), nil +} + +// GetForks returns the network fork ID intervals +func (z *ZKEVMEndpoints) GetForks() (interface{}, types.Error) { + ctx := context.Background() + forkIDIntervals, err := z.state.GetForkIDIntervals(ctx, nil) + if errors.Is(err, state.ErrStateNotSynchronized) { + return nil, nil + } else if err != nil { + return nil, types.NewRPCError(types.DefaultErrorCode, "failed to get the fork id intervals from state") + } + + res := make([]*types.ForkIDInterval, 0, len(forkIDIntervals)) + for _, forkIDInterval := range forkIDIntervals { + res = append(res, types.NewForkIDInterval(forkIDInterval)) + } + + slices.SortFunc(res, func(a *types.ForkIDInterval, b *types.ForkIDInterval) int { + if a.ForkId == b.ForkId { + return 0 + } else if a.ForkId > b.ForkId { + return 1 + } else { + return -1 + } + }) + + return res, nil +} diff --git a/jsonrpc/endpoints_zkevm.openrpc.json b/jsonrpc/endpoints_zkevm.openrpc.json index d795e0f1cb..eb71eb3248 100644 --- a/jsonrpc/endpoints_zkevm.openrpc.json +++ b/jsonrpc/endpoints_zkevm.openrpc.json @@ -471,6 +471,126 @@ "$ref": "#/components/schemas/Integer" } } + }, + { + "name": "zkevm_forkId", + "summary": "Returns the network's current fork ID.", + "params": [], + "result": { + "$ref": "#/components/contentDescriptors/ForkID" + }, + "examples": [ + { + "name": "example", + "description": "", + "params": [], + "result": { + "name": "exampleResult", + "description": "", + "value": "0x1" + } + } + ] + }, + { + "name": "zkevm_getForkById", + "summary": "returns the network fork ID interval given the provided fork id", + "params": [ + { + "$ref": "#/components/contentDescriptors/ForkID" + } + ], + "result": { + "$ref": "#/components/contentDescriptors/Fork" + }, + "examples": [ + { + "name": "example", + "description": "", + "params": [ + { + "name": "fork id", + "value": "0x1" + } + ], + "result": { + "name": "Fork", + "value": { + "forkId": "0x8", + "fromBatchNumber": "0x1", + "toBatchNumber": "0xffffffffffffffff", + "version": "", + "blockNumber": "0x88" + } + } + } + ] + }, + { + "name": "zkevm_getForkIdByBatchNumber", + "summary": "returns the fork ID given the provided batch number", + "params": [ + { + "$ref": "#/components/contentDescriptors/BatchNumber" + } + ], + "result": { + "$ref": "#/components/contentDescriptors/ForkID" + }, + "examples": [ + { + "name": "example", + "description": "", + "params": [], + "result": { + "name": "exampleResult", + "description": "", + "value": "0x1" + } + } + ] + }, + { + "name": "zkevm_getForks", + "summary": "returns the network fork ID interval given the provided fork id", + "params": [], + "result": { + "name": "result", + "schema": { + "description": "Array of forks", + "type": "array", + "items": { + "title": "fork", + "$ref": "#/components/contentDescriptors/Fork" + } + } + }, + "examples": [ + { + "name": "example", + "description": "", + "params": [], + "result": { + "name": "Fork", + "value": [ + { + "forkId": "0x8", + "fromBatchNumber": "0x1", + "toBatchNumber": "0xa", + "version": "", + "blockNumber": "0x88" + }, + { + "forkId": "0x9", + "fromBatchNumber": "0xb", + "toBatchNumber": "0xffffffffffffffff", + "version": "", + "blockNumber": "0x188" + } + ] + } + } + ] } ], "components": { @@ -520,6 +640,14 @@ "$ref": "#/components/schemas/Block" } }, + "Fork": { + "name": "fork", + "description": "fork", + "required": true, + "schema": { + "$ref": "#/components/schemas/Fork" + } + }, "Transaction": { "required": true, "name": "transaction", @@ -548,6 +676,13 @@ } ] } + }, + "ForkID": { + "name": "forkID", + "required": true, + "schema": { + "$ref": "#/components/schemas/ForkID" + } } }, "schemas": { @@ -1437,6 +1572,36 @@ "$ref": "#/components/schemas/Integer" } } + }, + "ForkID": { + "title": "forkID", + "type": "string", + "description": "The hex representation of the fork's id", + "$ref": "#/components/schemas/Integer" + }, + "Fork": { + "title": "Fork", + "type": "object", + "readOnly": true, + "properties": { + "forkId": { + "$ref": "#/components/schemas/ForkID" + }, + "fromBatchNumber": { + "$ref": "#/components/schemas/BatchNumber" + }, + "toBatchNumber": { + "$ref": "#/components/schemas/BatchNumber" + }, + "Version": { + "title": "batchNumberTag", + "type": "string", + "description": "fork version" + }, + "BlockNumber": { + "$ref": "#/components/schemas/BlockNumber" + } + } } } } diff --git a/jsonrpc/endpoints_zkevm_test.go b/jsonrpc/endpoints_zkevm_test.go index ff2761158f..dae375877f 100644 --- a/jsonrpc/endpoints_zkevm_test.go +++ b/jsonrpc/endpoints_zkevm_test.go @@ -2180,3 +2180,320 @@ func TestGetLatestGlobalExitRoot(t *testing.T) { }) } } + +func TestGetForkId(t *testing.T) { + s, m, _ := newSequencerMockedServer(t) + defer s.Stop() + + type testCase struct { + Name string + ExpectedResult uint64 + ExpectedError types.Error + SetupMocks func(m *mocksWrapper) + } + + testCases := []testCase{ + { + Name: "get fork id successfully", + ExpectedError: nil, + ExpectedResult: 10, + SetupMocks: func(m *mocksWrapper) { + m.State. + On("GetCurrentForkID", context.Background(), nil). + Return(uint64(10), nil). + Once() + }, + }, + { + Name: "failed to get fork id", + ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the current fork id from state"), + ExpectedResult: 0, + SetupMocks: func(m *mocksWrapper) { + m.State. + On("GetCurrentForkID", context.Background(), nil). + Return(uint64(0), errors.New("failed to get current fork id")). + Once() + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + tc := testCase + tc.SetupMocks(m) + + res, err := s.JSONRPCCall("zkevm_getForkId") + require.NoError(t, err) + + if res.Result != nil { + var result types.ArgUint64 + err = json.Unmarshal(res.Result, &result) + require.NoError(t, err) + assert.Equal(t, tc.ExpectedResult, uint64(result)) + } + + if res.Error != nil || tc.ExpectedError != nil { + assert.Equal(t, tc.ExpectedError.ErrorCode(), res.Error.Code) + assert.Equal(t, tc.ExpectedError.Error(), res.Error.Message) + } + }) + } +} + +func TestGetForkById(t *testing.T) { + s, m, _ := newSequencerMockedServer(t) + defer s.Stop() + + const forkID = uint64(3) + + type testCase struct { + Name string + ExpectedResult *types.ForkIDInterval + ExpectedError types.Error + SetupMocks func(m *mocksWrapper) + } + + testCases := []testCase{ + { + Name: "get fork by id successfully", + ExpectedError: nil, + ExpectedResult: &types.ForkIDInterval{ + ForkId: 3, + FromBatchNumber: 1, + ToBatchNumber: 2, + Version: "0.0.1", + BlockNumber: 10, + }, + SetupMocks: func(m *mocksWrapper) { + m.State. + On("GetForkByID", context.Background(), forkID, nil). + Return(&state.ForkIDInterval{ + ForkId: 3, + FromBatchNumber: 1, + ToBatchNumber: 2, + Version: "0.0.1", + BlockNumber: 10, + }, nil). + Once() + }, + }, + { + Name: "failed to get fork by id", + ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the fork interval by id from state"), + ExpectedResult: nil, + SetupMocks: func(m *mocksWrapper) { + m.State. + On("GetForkByID", context.Background(), forkID, nil). + Return(nil, errors.New("failed to get fork by id")). + Once() + }, + }, + { + Name: "fork by id not found", + ExpectedError: nil, + ExpectedResult: nil, + SetupMocks: func(m *mocksWrapper) { + m.State. + On("GetForkByID", context.Background(), forkID, nil). + Return(nil, state.ErrNotFound). + Once() + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + tc := testCase + tc.SetupMocks(m) + + res, err := s.JSONRPCCall("zkevm_getForkById", hex.EncodeUint64(forkID)) + require.NoError(t, err) + + if res.Result != nil || tc.ExpectedResult != nil { + var result *types.ForkIDInterval + err = json.Unmarshal(res.Result, &result) + require.NoError(t, err) + + if tc.ExpectedResult == nil { + assert.Nil(t, result) + } else { + assert.Equal(t, tc.ExpectedResult.ForkId, result.ForkId) + assert.Equal(t, tc.ExpectedResult.FromBatchNumber, result.FromBatchNumber) + assert.Equal(t, tc.ExpectedResult.ToBatchNumber, result.ToBatchNumber) + assert.Equal(t, tc.ExpectedResult.Version, result.Version) + assert.Equal(t, tc.ExpectedResult.BlockNumber, result.BlockNumber) + } + } + + if res.Error != nil || tc.ExpectedError != nil { + assert.Equal(t, tc.ExpectedError.ErrorCode(), res.Error.Code) + assert.Equal(t, tc.ExpectedError.Error(), res.Error.Message) + } + }) + } +} + +func TestGetForkIdByBatchNumber(t *testing.T) { + s, m, _ := newSequencerMockedServer(t) + defer s.Stop() + forkID := uint64(1) + batchNumber := uint64(2) + + type testCase struct { + Name string + ExpectedResult *uint64 + ExpectedError types.Error + SetupMocks func(m *mocksWrapper) + } + + testCases := []testCase{ + { + Name: "get fork id by batch number successfully", + ExpectedResult: &forkID, + SetupMocks: func(m *mocksWrapper) { + m.State. + On("GetForkIDByBatchNumber", batchNumber). + Return(forkID, nil). + Once() + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + tc := testCase + tc.SetupMocks(m) + + res, err := s.JSONRPCCall("zkevm_getForkIdByBatchNumber", hex.EncodeUint64(batchNumber)) + require.NoError(t, err) + + if tc.ExpectedResult != nil { + var result types.ArgUint64 + err = json.Unmarshal(res.Result, &result) + require.NoError(t, err) + assert.Equal(t, *tc.ExpectedResult, uint64(result)) + } else { + if res.Result == nil { + assert.Nil(t, res.Result) + } else { + var result *uint64 + err = json.Unmarshal(res.Result, &result) + require.NoError(t, err) + assert.Nil(t, result) + } + } + + if tc.ExpectedError != nil { + assert.Equal(t, tc.ExpectedError.ErrorCode(), res.Error.Code) + assert.Equal(t, tc.ExpectedError.Error(), res.Error.Message) + } else { + assert.Nil(t, res.Error) + } + }) + } +} + +func TestGetForks(t *testing.T) { + s, m, _ := newSequencerMockedServer(t) + defer s.Stop() + + type testCase struct { + Name string + ExpectedResult []types.ForkIDInterval + ExpectedError types.Error + SetupMocks func(m *mocksWrapper) + } + + testCases := []testCase{ + { + Name: "get forks successfully", + ExpectedError: nil, + ExpectedResult: []types.ForkIDInterval{ + { + ForkId: 1, + FromBatchNumber: 1, + ToBatchNumber: 2, + Version: "0.0.1", + BlockNumber: 5, + }, + { + ForkId: 2, + FromBatchNumber: 3, + ToBatchNumber: 4, + Version: "0.0.2", + BlockNumber: 10, + }, + }, + SetupMocks: func(m *mocksWrapper) { + m.State. + On("GetForkIDIntervals", context.Background(), nil). + Return([]state.ForkIDInterval{ + { + ForkId: 1, + FromBatchNumber: 1, + ToBatchNumber: 2, + Version: "0.0.1", + BlockNumber: 5, + }, + { + ForkId: 2, + FromBatchNumber: 3, + ToBatchNumber: 4, + Version: "0.0.2", + BlockNumber: 10, + }, + }, nil). + Once() + }, + }, + { + Name: "failed to get forks", + ExpectedError: types.NewRPCError(types.DefaultErrorCode, "failed to get the fork id intervals from state"), + ExpectedResult: nil, + SetupMocks: func(m *mocksWrapper) { + m.State. + On("GetForkIDIntervals", context.Background(), nil). + Return(nil, errors.New("failed to get fork id intervals")). + Once() + }, + }, + { + Name: "forks when state is not synchronized yet", + ExpectedError: nil, + ExpectedResult: nil, + SetupMocks: func(m *mocksWrapper) { + m.State. + On("GetForkIDIntervals", context.Background(), nil). + Return(nil, state.ErrStateNotSynchronized). + Once() + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + tc := testCase + tc.SetupMocks(m) + + res, err := s.JSONRPCCall("zkevm_getForks") + require.NoError(t, err) + + if res.Result != nil || tc.ExpectedResult != nil { + var result []types.ForkIDInterval + err = json.Unmarshal(res.Result, &result) + require.NoError(t, err) + + if tc.ExpectedResult == nil { + assert.Nil(t, result) + } else { + assert.ElementsMatch(t, tc.ExpectedResult, result) + } + } + + if res.Error != nil || tc.ExpectedError != nil { + assert.Equal(t, tc.ExpectedError.ErrorCode(), res.Error.Code) + assert.Equal(t, tc.ExpectedError.Error(), res.Error.Message) + } + }) + } +} diff --git a/jsonrpc/mocks/mock_state.go b/jsonrpc/mocks/mock_state.go index 36f552fe65..a9c2a761da 100644 --- a/jsonrpc/mocks/mock_state.go +++ b/jsonrpc/mocks/mock_state.go @@ -271,6 +271,34 @@ func (_m *StateMock) GetCode(ctx context.Context, address common.Address, root c return r0, r1 } +// GetCurrentForkID provides a mock function with given fields: ctx, dbTx +func (_m *StateMock) GetCurrentForkID(ctx context.Context, dbTx pgx.Tx) (uint64, error) { + ret := _m.Called(ctx, dbTx) + + if len(ret) == 0 { + panic("no return value specified for GetCurrentForkID") + } + + var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, pgx.Tx) (uint64, error)); ok { + return rf(ctx, dbTx) + } + if rf, ok := ret.Get(0).(func(context.Context, pgx.Tx) uint64); ok { + r0 = rf(ctx, dbTx) + } else { + r0 = ret.Get(0).(uint64) + } + + if rf, ok := ret.Get(1).(func(context.Context, pgx.Tx) error); ok { + r1 = rf(ctx, dbTx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetExitRootByGlobalExitRoot provides a mock function with given fields: ctx, ger, dbTx func (_m *StateMock) GetExitRootByGlobalExitRoot(ctx context.Context, ger common.Hash, dbTx pgx.Tx) (*state.GlobalExitRoot, error) { ret := _m.Called(ctx, ger, dbTx) @@ -301,6 +329,84 @@ func (_m *StateMock) GetExitRootByGlobalExitRoot(ctx context.Context, ger common return r0, r1 } +// GetForkByID provides a mock function with given fields: ctx, forkID, dbTx +func (_m *StateMock) GetForkByID(ctx context.Context, forkID uint64, dbTx pgx.Tx) (*state.ForkIDInterval, error) { + ret := _m.Called(ctx, forkID, dbTx) + + if len(ret) == 0 { + panic("no return value specified for GetForkByID") + } + + var r0 *state.ForkIDInterval + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uint64, pgx.Tx) (*state.ForkIDInterval, error)); ok { + return rf(ctx, forkID, dbTx) + } + if rf, ok := ret.Get(0).(func(context.Context, uint64, pgx.Tx) *state.ForkIDInterval); ok { + r0 = rf(ctx, forkID, dbTx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*state.ForkIDInterval) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, uint64, pgx.Tx) error); ok { + r1 = rf(ctx, forkID, dbTx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetForkIDByBatchNumber provides a mock function with given fields: batchNumber +func (_m *StateMock) GetForkIDByBatchNumber(batchNumber uint64) uint64 { + ret := _m.Called(batchNumber) + + if len(ret) == 0 { + panic("no return value specified for GetForkIDByBatchNumber") + } + + var r0 uint64 + if rf, ok := ret.Get(0).(func(uint64) uint64); ok { + r0 = rf(batchNumber) + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +// GetForkIDIntervals provides a mock function with given fields: ctx, dbTx +func (_m *StateMock) GetForkIDIntervals(ctx context.Context, dbTx pgx.Tx) ([]state.ForkIDInterval, error) { + ret := _m.Called(ctx, dbTx) + + if len(ret) == 0 { + panic("no return value specified for GetForkIDIntervals") + } + + var r0 []state.ForkIDInterval + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, pgx.Tx) ([]state.ForkIDInterval, error)); ok { + return rf(ctx, dbTx) + } + if rf, ok := ret.Get(0).(func(context.Context, pgx.Tx) []state.ForkIDInterval); ok { + r0 = rf(ctx, dbTx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]state.ForkIDInterval) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, pgx.Tx) error); ok { + r1 = rf(ctx, dbTx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetL2BlockByHash provides a mock function with given fields: ctx, hash, dbTx func (_m *StateMock) GetL2BlockByHash(ctx context.Context, hash common.Hash, dbTx pgx.Tx) (*state.L2Block, error) { ret := _m.Called(ctx, hash, dbTx) diff --git a/jsonrpc/types/interfaces.go b/jsonrpc/types/interfaces.go index c12040dea7..9cbd621d0d 100644 --- a/jsonrpc/types/interfaces.go +++ b/jsonrpc/types/interfaces.go @@ -76,6 +76,10 @@ type StateInterface interface { GetLatestBatchGlobalExitRoot(ctx context.Context, dbTx pgx.Tx) (common.Hash, error) GetL2TxHashByTxHash(ctx context.Context, hash common.Hash, dbTx pgx.Tx) (*common.Hash, error) PreProcessUnsignedTransaction(ctx context.Context, tx *types.Transaction, sender common.Address, l2BlockNumber *uint64, dbTx pgx.Tx) (*state.ProcessBatchResponse, error) + GetCurrentForkID(ctx context.Context, dbTx pgx.Tx) (uint64, error) + GetForkIDByBatchNumber(batchNumber uint64) uint64 + GetForkByID(ctx context.Context, forkID uint64, dbTx pgx.Tx) (*state.ForkIDInterval, error) + GetForkIDIntervals(ctx context.Context, dbTx pgx.Tx) ([]state.ForkIDInterval, error) } // EthermanInterface provides integration with L1 diff --git a/jsonrpc/types/types.go b/jsonrpc/types/types.go index b9c902cc1a..dae0a48367 100644 --- a/jsonrpc/types/types.go +++ b/jsonrpc/types/types.go @@ -773,3 +773,24 @@ func NewZKCountersResponse(zkCounters state.ZKCounters, limits ZKCountersLimits, OOCError: oocErrMsg, } } + +// ForkIDInterval provides fork id information +type ForkIDInterval struct { + ForkId ArgUint64 `json:"forkId"` + FromBatchNumber ArgUint64 `json:"fromBatchNumber"` + ToBatchNumber ArgUint64 `json:"toBatchNumber"` + Version string `json:"version"` + BlockNumber ArgUint64 `json:"blockNumber"` +} + +// NewForkIDInterval creates a new instance of ForkIDInterval +// given a state.ForkIDInterval instance +func NewForkIDInterval(forkIDInterval state.ForkIDInterval) *ForkIDInterval { + return &ForkIDInterval{ + ForkId: ArgUint64(forkIDInterval.ForkId), + FromBatchNumber: ArgUint64(forkIDInterval.FromBatchNumber), + ToBatchNumber: ArgUint64(forkIDInterval.ToBatchNumber), + Version: forkIDInterval.Version, + BlockNumber: ArgUint64(forkIDInterval.BlockNumber), + } +} diff --git a/state/forkid.go b/state/forkid.go index ed035a53e1..08d520c352 100644 --- a/state/forkid.go +++ b/state/forkid.go @@ -49,3 +49,18 @@ func (s *State) GetForkIDByBatchNumber(batchNumber uint64) uint64 { func (s *State) GetForkIDByBlockNumber(blockNumber uint64) uint64 { return s.storage.GetForkIDByBlockNumber(blockNumber) } + +// GetCurrentForkID gets the current fork id +func (s *State) GetCurrentForkID(ctx context.Context, dbTx pgx.Tx) (uint64, error) { + return s.storage.GetCurrentForkID(ctx, dbTx) +} + +// GetForkByID gets the fork id interval by fork id number +func (s *State) GetForkByID(ctx context.Context, forkID uint64, dbTx pgx.Tx) (*ForkIDInterval, error) { + return s.storage.GetForkByID(ctx, forkID, dbTx) +} + +// GetForkIDIntervals gets all fork id intervals +func (s *State) GetForkIDIntervals(ctx context.Context, dbTx pgx.Tx) ([]ForkIDInterval, error) { + return s.storage.GetForkIDs(ctx, dbTx) +} diff --git a/state/interfaces.go b/state/interfaces.go index 33e8bc01be..59d4fc503f 100644 --- a/state/interfaces.go +++ b/state/interfaces.go @@ -123,6 +123,8 @@ type storage interface { GetLatestGer(ctx context.Context, maxBlockNumber uint64) (GlobalExitRoot, time.Time, error) GetBatchByForcedBatchNum(ctx context.Context, forcedBatchNumber uint64, dbTx pgx.Tx) (*Batch, error) AddForkID(ctx context.Context, forkID ForkIDInterval, dbTx pgx.Tx) error + GetCurrentForkID(ctx context.Context, dbTx pgx.Tx) (uint64, error) + GetForkByID(ctx context.Context, forkID uint64, dbTx pgx.Tx) (*ForkIDInterval, error) GetForkIDs(ctx context.Context, dbTx pgx.Tx) ([]ForkIDInterval, error) UpdateForkIDToBatchNumber(ctx context.Context, forkID ForkIDInterval, dbTx pgx.Tx) error UpdateForkIDBlockNumber(ctx context.Context, forkdID uint64, newBlockNumber uint64, updateMemCache bool, dbTx pgx.Tx) error diff --git a/state/mocks/mock_storage.go b/state/mocks/mock_storage.go index 57b72a61f6..b1ba7cd893 100644 --- a/state/mocks/mock_storage.go +++ b/state/mocks/mock_storage.go @@ -1950,6 +1950,63 @@ func (_c *StorageMock_GetBlockNumVirtualBatchByBatchNum_Call) RunAndReturn(run f return _c } +// GetCurrentForkID provides a mock function with given fields: ctx, dbTx +func (_m *StorageMock) GetCurrentForkID(ctx context.Context, dbTx pgx.Tx) (uint64, error) { + ret := _m.Called(ctx, dbTx) + + if len(ret) == 0 { + panic("no return value specified for GetCurrentForkID") + } + + var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, pgx.Tx) (uint64, error)); ok { + return rf(ctx, dbTx) + } + if rf, ok := ret.Get(0).(func(context.Context, pgx.Tx) uint64); ok { + r0 = rf(ctx, dbTx) + } else { + r0 = ret.Get(0).(uint64) + } + + if rf, ok := ret.Get(1).(func(context.Context, pgx.Tx) error); ok { + r1 = rf(ctx, dbTx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// StorageMock_GetCurrentForkID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCurrentForkID' +type StorageMock_GetCurrentForkID_Call struct { + *mock.Call +} + +// GetCurrentForkID is a helper method to define mock.On call +// - ctx context.Context +// - dbTx pgx.Tx +func (_e *StorageMock_Expecter) GetCurrentForkID(ctx interface{}, dbTx interface{}) *StorageMock_GetCurrentForkID_Call { + return &StorageMock_GetCurrentForkID_Call{Call: _e.mock.On("GetCurrentForkID", ctx, dbTx)} +} + +func (_c *StorageMock_GetCurrentForkID_Call) Run(run func(ctx context.Context, dbTx pgx.Tx)) *StorageMock_GetCurrentForkID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(pgx.Tx)) + }) + return _c +} + +func (_c *StorageMock_GetCurrentForkID_Call) Return(_a0 uint64, _a1 error) *StorageMock_GetCurrentForkID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *StorageMock_GetCurrentForkID_Call) RunAndReturn(run func(context.Context, pgx.Tx) (uint64, error)) *StorageMock_GetCurrentForkID_Call { + _c.Call.Return(run) + return _c +} + // GetDSBatches provides a mock function with given fields: ctx, firstBatchNumber, lastBatchNumber, readWIPBatch, dbTx func (_m *StorageMock) GetDSBatches(ctx context.Context, firstBatchNumber uint64, lastBatchNumber uint64, readWIPBatch bool, dbTx pgx.Tx) ([]*state.DSBatch, error) { ret := _m.Called(ctx, firstBatchNumber, lastBatchNumber, readWIPBatch, dbTx) @@ -2621,6 +2678,66 @@ func (_c *StorageMock_GetForcedBatchesSince_Call) RunAndReturn(run func(context. return _c } +// GetForkByID provides a mock function with given fields: ctx, forkID, dbTx +func (_m *StorageMock) GetForkByID(ctx context.Context, forkID uint64, dbTx pgx.Tx) (*state.ForkIDInterval, error) { + ret := _m.Called(ctx, forkID, dbTx) + + if len(ret) == 0 { + panic("no return value specified for GetForkByID") + } + + var r0 *state.ForkIDInterval + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uint64, pgx.Tx) (*state.ForkIDInterval, error)); ok { + return rf(ctx, forkID, dbTx) + } + if rf, ok := ret.Get(0).(func(context.Context, uint64, pgx.Tx) *state.ForkIDInterval); ok { + r0 = rf(ctx, forkID, dbTx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*state.ForkIDInterval) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, uint64, pgx.Tx) error); ok { + r1 = rf(ctx, forkID, dbTx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// StorageMock_GetForkByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetForkByID' +type StorageMock_GetForkByID_Call struct { + *mock.Call +} + +// GetForkByID is a helper method to define mock.On call +// - ctx context.Context +// - forkID uint64 +// - dbTx pgx.Tx +func (_e *StorageMock_Expecter) GetForkByID(ctx interface{}, forkID interface{}, dbTx interface{}) *StorageMock_GetForkByID_Call { + return &StorageMock_GetForkByID_Call{Call: _e.mock.On("GetForkByID", ctx, forkID, dbTx)} +} + +func (_c *StorageMock_GetForkByID_Call) Run(run func(ctx context.Context, forkID uint64, dbTx pgx.Tx)) *StorageMock_GetForkByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uint64), args[2].(pgx.Tx)) + }) + return _c +} + +func (_c *StorageMock_GetForkByID_Call) Return(_a0 *state.ForkIDInterval, _a1 error) *StorageMock_GetForkByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *StorageMock_GetForkByID_Call) RunAndReturn(run func(context.Context, uint64, pgx.Tx) (*state.ForkIDInterval, error)) *StorageMock_GetForkByID_Call { + _c.Call.Return(run) + return _c +} + // GetForkIDByBatchNumber provides a mock function with given fields: batchNumber func (_m *StorageMock) GetForkIDByBatchNumber(batchNumber uint64) uint64 { ret := _m.Called(batchNumber) diff --git a/state/pgstatestorage/forkid.go b/state/pgstatestorage/forkid.go index dbe865bc4f..aac3ffc56a 100644 --- a/state/pgstatestorage/forkid.go +++ b/state/pgstatestorage/forkid.go @@ -152,7 +152,7 @@ func (p *PostgresStorage) GetForkIDByBlockNumber(blockNumber uint64) uint64 { // GetForkIDByBlockNumber returns the fork id for a given block number in memory func (p *PostgresStorage) GetForkIDByBlockNumberInMemory(blockNumber uint64) uint64 { - for _, index := range sortIndexForForkdIDSortedByBlockNumber(p.cfg.ForkIDIntervals) { + for _, index := range sortIndexForForkIDSortedByBlockNumber(p.cfg.ForkIDIntervals) { // reverse travesal interval := p.cfg.ForkIDIntervals[len(p.cfg.ForkIDIntervals)-1-index] if blockNumber >= interval.BlockNumber { @@ -163,7 +163,7 @@ func (p *PostgresStorage) GetForkIDByBlockNumberInMemory(blockNumber uint64) uin return 1 } -func sortIndexForForkdIDSortedByBlockNumber(forkIDs []state.ForkIDInterval) []int { +func sortIndexForForkIDSortedByBlockNumber(forkIDs []state.ForkIDInterval) []int { sortedIndex := make([]int, len(forkIDs)) for i := range sortedIndex { sortedIndex[i] = i @@ -240,6 +240,37 @@ func (p *PostgresStorage) GetForkIDByBatchNumberInMemory(batchNumber uint64) uin return p.cfg.ForkIDIntervals[len(p.cfg.ForkIDIntervals)-1].ForkId } +// GetForkByID returns the fork id interval for a given fork id +func (p *PostgresStorage) GetForkByID(ctx context.Context, forkId uint64, dbTx pgx.Tx) (*state.ForkIDInterval, error) { + if p.cfg.AvoidForkIDInMemory { + const getForkIDsSQL = "SELECT from_batch_num, to_batch_num, fork_id, version, block_num FROM state.fork_id WHERE fork_id = $1" + q := p.getExecQuerier(dbTx) + + forkIDInterval := &state.ForkIDInterval{} + + err := q.QueryRow(ctx, getForkIDsSQL, forkId).Scan( + &forkIDInterval.FromBatchNumber, + &forkIDInterval.ToBatchNumber, + &forkIDInterval.ForkId, + &forkIDInterval.Version, + &forkIDInterval.BlockNumber, + ) + if errors.Is(err, pgx.ErrNoRows) { + return nil, state.ErrNotFound + } else if err != nil { + return nil, err + } + + return forkIDInterval, nil + } else { + forkIdInterval := p.GetForkIDInMemory(forkId) + if p.GetForkIDInMemory(forkId) == nil { + return nil, state.ErrNotFound + } + return forkIdInterval, nil + } +} + // GetForkIDInMemory get the forkIDs stored in cache, or nil if not found func (p *PostgresStorage) GetForkIDInMemory(forkId uint64) *state.ForkIDInterval { for _, interval := range p.cfg.ForkIDIntervals { @@ -249,3 +280,14 @@ func (p *PostgresStorage) GetForkIDInMemory(forkId uint64) *state.ForkIDInterval } return nil } + +// GetCurrentForkID gets the current fork id +func (p *PostgresStorage) GetCurrentForkID(ctx context.Context, dbTx pgx.Tx) (uint64, error) { + lastBatchNumber, err := p.GetLastBatchNumber(ctx, dbTx) + if err != nil { + return 0, err + } + + forkID := p.GetForkIDByBatchNumber(lastBatchNumber) + return forkID, nil +} diff --git a/state/pgstatestorage/forkid_test.go b/state/pgstatestorage/forkid_test.go index ff698be11b..45a91da618 100644 --- a/state/pgstatestorage/forkid_test.go +++ b/state/pgstatestorage/forkid_test.go @@ -17,7 +17,7 @@ func TestSortIndexForForkdIDSortedByBlockNumber(t *testing.T) { } expected := []int{3, 1, 0, 2} - actual := sortIndexForForkdIDSortedByBlockNumber(forkIDs) + actual := sortIndexForForkIDSortedByBlockNumber(forkIDs) assert.Equal(t, expected, actual)