diff --git a/fuzz/testing/BUILD.bazel b/fuzz/testing/BUILD.bazel index 555426beebef..5e1888eed8f4 100644 --- a/fuzz/testing/BUILD.bazel +++ b/fuzz/testing/BUILD.bazel @@ -22,4 +22,5 @@ go_test( name = "go_default_test", srcs = ["beacon_fuzz_states_test.go"], embed = [":go_default_library"], + deps = ["//shared/testutil/require:go_default_library"], ) diff --git a/fuzz/testing/beacon_fuzz_states_test.go b/fuzz/testing/beacon_fuzz_states_test.go index 827276009c1a..836956d4672f 100644 --- a/fuzz/testing/beacon_fuzz_states_test.go +++ b/fuzz/testing/beacon_fuzz_states_test.go @@ -2,10 +2,11 @@ package testing import ( "testing" + + "github.com/prysmaticlabs/prysm/shared/testutil/require" ) func TestGetBeaconFuzzState(t *testing.T) { - if _, err := GetBeaconFuzzState(1); err != nil { - t.Fatal(err) - } + _, err := GetBeaconFuzzState(1) + require.NoError(t, err) } diff --git a/slasher/BUILD.bazel b/slasher/BUILD.bazel index f9d2130d8cea..70ecb4cd2b70 100644 --- a/slasher/BUILD.bazel +++ b/slasher/BUILD.bazel @@ -35,6 +35,7 @@ go_test( embed = [":go_default_library"], deps = [ "//shared/featureconfig:go_default_library", + "//shared/testutil/assert:go_default_library", "@com_github_urfave_cli_v2//:go_default_library", ], ) diff --git a/slasher/beaconclient/BUILD.bazel b/slasher/beaconclient/BUILD.bazel index f058ac32e067..39063000207d 100644 --- a/slasher/beaconclient/BUILD.bazel +++ b/slasher/beaconclient/BUILD.bazel @@ -59,10 +59,10 @@ go_test( "//shared/mock:go_default_library", "//shared/params:go_default_library", "//shared/slotutil:go_default_library", + "//shared/testutil/assert:go_default_library", "//shared/testutil/require:go_default_library", "//slasher/cache:go_default_library", "//slasher/db/testing:go_default_library", - "@com_github_gogo_protobuf//proto:go_default_library", "@com_github_gogo_protobuf//types:go_default_library", "@com_github_golang_mock//gomock:go_default_library", "@com_github_prysmaticlabs_ethereumapis//eth/v1alpha1:go_default_library", diff --git a/slasher/beaconclient/chain_data_test.go b/slasher/beaconclient/chain_data_test.go index 68b8a13e9e0d..fa9d9a7ec1f3 100644 --- a/slasher/beaconclient/chain_data_test.go +++ b/slasher/beaconclient/chain_data_test.go @@ -1,15 +1,14 @@ package beaconclient import ( - "bytes" "context" "testing" "time" - "github.com/gogo/protobuf/proto" "github.com/golang/mock/gomock" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" "github.com/prysmaticlabs/prysm/shared/mock" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" "github.com/prysmaticlabs/prysm/shared/testutil/require" logTest "github.com/sirupsen/logrus/hooks/test" ) @@ -29,12 +28,8 @@ func TestService_ChainHead(t *testing.T) { } client.EXPECT().GetChainHead(gomock.Any(), gomock.Any()).Return(wanted, nil) res, err := bs.ChainHead(context.Background()) - if err != nil { - t.Fatal(err) - } - if !proto.Equal(res, wanted) { - t.Errorf("Wanted %v, received %v", wanted, res) - } + require.NoError(t, err) + require.DeepEqual(t, wanted, res) } func TestService_GenesisValidatorsRoot(t *testing.T) { @@ -50,20 +45,12 @@ func TestService_GenesisValidatorsRoot(t *testing.T) { } client.EXPECT().GetGenesis(gomock.Any(), gomock.Any()).Return(wanted, nil) res, err := bs.GenesisValidatorsRoot(context.Background()) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(res, wanted.GenesisValidatorsRoot) { - t.Errorf("Wanted %#x, received %#x", wanted.GenesisValidatorsRoot, res) - } + require.NoError(t, err) + assert.DeepEqual(t, wanted.GenesisValidatorsRoot, res, "Wanted %#x, received %#x", wanted.GenesisValidatorsRoot, res) // test next fetch uses memory and not the rpc call. res, err = bs.GenesisValidatorsRoot(context.Background()) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(res, wanted.GenesisValidatorsRoot) { - t.Errorf("Wanted %#x, received %#x", wanted.GenesisValidatorsRoot, res) - } + require.NoError(t, err) + assert.DeepEqual(t, wanted.GenesisValidatorsRoot, res, "Wanted %#x, received %#x", wanted.GenesisValidatorsRoot, res) } func TestService_QuerySyncStatus(t *testing.T) { diff --git a/slasher/beaconclient/historical_data_retrieval_test.go b/slasher/beaconclient/historical_data_retrieval_test.go index 5b9f222ddd54..2f4a8f4afe0d 100644 --- a/slasher/beaconclient/historical_data_retrieval_test.go +++ b/slasher/beaconclient/historical_data_retrieval_test.go @@ -2,7 +2,6 @@ package beaconclient import ( "context" - "reflect" "strconv" "testing" @@ -10,6 +9,7 @@ import ( ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" "github.com/prysmaticlabs/prysm/shared/mock" "github.com/prysmaticlabs/prysm/shared/params" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" "github.com/prysmaticlabs/prysm/shared/testutil/require" testDB "github.com/prysmaticlabs/prysm/slasher/db/testing" logTest "github.com/sirupsen/logrus/hooks/test" @@ -77,12 +77,8 @@ func TestService_RequestHistoricalAttestations(t *testing.T) { // We request attestations for epoch 0. res, err := bs.RequestHistoricalAttestations(context.Background(), 0) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(res, wanted) { - t.Errorf("Wanted %v, received %v", wanted, res) - } + require.NoError(t, err) + assert.DeepEqual(t, wanted, res) require.LogsContain(t, hook, "Retrieved 100/1000 indexed attestations for epoch 0") require.LogsContain(t, hook, "Retrieved 500/1000 indexed attestations for epoch 0") require.LogsContain(t, hook, "Retrieved 1000/1000 indexed attestations for epoch 0") diff --git a/slasher/beaconclient/receivers_test.go b/slasher/beaconclient/receivers_test.go index 11fc54bb49d8..e0abb7a908e5 100644 --- a/slasher/beaconclient/receivers_test.go +++ b/slasher/beaconclient/receivers_test.go @@ -11,6 +11,7 @@ import ( "github.com/prysmaticlabs/prysm/shared/event" "github.com/prysmaticlabs/prysm/shared/mock" "github.com/prysmaticlabs/prysm/shared/slotutil" + "github.com/prysmaticlabs/prysm/shared/testutil/require" testDB "github.com/prysmaticlabs/prysm/slasher/db/testing" ) @@ -117,7 +118,5 @@ func TestService_ReceiveAttestations_Batched(t *testing.T) { att.Data.Target.Root = []byte("test root 3") bs.receivedAttestationsBuffer <- att atts := <-bs.collectedAttestationsBuffer - if len(atts) != 3 { - t.Fatalf("Expected %d received attestations to be batched", len(atts)) - } + require.Equal(t, 3, len(atts), "Unexpected number of attestations batched") } diff --git a/slasher/beaconclient/service_test.go b/slasher/beaconclient/service_test.go index bd997a6a507c..8dd801c112e8 100644 --- a/slasher/beaconclient/service_test.go +++ b/slasher/beaconclient/service_test.go @@ -1,6 +1,21 @@ package beaconclient +import ( + "io/ioutil" + "os" + "testing" + + "github.com/sirupsen/logrus" +) + var ( _ = Notifier(&Service{}) _ = ChainFetcher(&Service{}) ) + +func TestMain(m *testing.M) { + logrus.SetLevel(logrus.DebugLevel) + logrus.SetOutput(ioutil.Discard) + + os.Exit(m.Run()) +} diff --git a/slasher/beaconclient/validator_retrieval_test.go b/slasher/beaconclient/validator_retrieval_test.go index da983a7a5d96..ba41772db4d8 100644 --- a/slasher/beaconclient/validator_retrieval_test.go +++ b/slasher/beaconclient/validator_retrieval_test.go @@ -1,13 +1,13 @@ package beaconclient import ( - "bytes" "context" "testing" "github.com/golang/mock/gomock" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" "github.com/prysmaticlabs/prysm/shared/mock" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" "github.com/prysmaticlabs/prysm/shared/testutil/require" "github.com/prysmaticlabs/prysm/slasher/cache" "github.com/sirupsen/logrus" @@ -21,9 +21,7 @@ func TestService_RequestValidator(t *testing.T) { defer ctrl.Finish() client := mock.NewMockBeaconChainClient(ctrl) validatorCache, err := cache.NewPublicKeyCache(0, nil) - if err != nil { - t.Fatalf("could not create new cache: %v", err) - } + require.NoError(t, err, "Could not create new cache") bs := Service{ beaconClient: client, publicKeyCache: validatorCache, @@ -57,27 +55,19 @@ func TestService_RequestValidator(t *testing.T) { // We request public key of validator id 0,1. res, err := bs.FindOrGetPublicKeys(context.Background(), []uint64{0, 1}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) for i, v := range wanted.ValidatorList { - if !bytes.Equal(res[v.Index], wanted.ValidatorList[i].Validator.PublicKey) { - t.Errorf("Wanted %v, received %v", wanted, res) - } + assert.DeepEqual(t, wanted.ValidatorList[i].Validator.PublicKey, res[v.Index]) } require.LogsContain(t, hook, "Retrieved validators id public key map:") require.LogsDoNotContain(t, hook, "Retrieved validators public keys from cache:") // We expect public key of validator id 0 to be in cache. res, err = bs.FindOrGetPublicKeys(context.Background(), []uint64{0, 3}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) for i, v := range wanted2.ValidatorList { - if !bytes.Equal(res[v.Index], wanted2.ValidatorList[i].Validator.PublicKey) { - t.Errorf("Wanted %v, received %v", wanted2, res) - } + assert.DeepEqual(t, wanted2.ValidatorList[i].Validator.PublicKey, res[v.Index]) } require.LogsContain(t, hook, "Retrieved validators public keys from cache: map[0:[1 2 3]]") } diff --git a/slasher/db/kv/BUILD.bazel b/slasher/db/kv/BUILD.bazel index b8f0632344ee..9ae5b5018e00 100644 --- a/slasher/db/kv/BUILD.bazel +++ b/slasher/db/kv/BUILD.bazel @@ -56,10 +56,12 @@ go_test( "//shared/bytesutil:go_default_library", "//shared/params:go_default_library", "//shared/testutil:go_default_library", + "//shared/testutil/assert:go_default_library", + "//shared/testutil/require:go_default_library", "//slasher/db/types:go_default_library", "//slasher/detection/attestations/types:go_default_library", - "@com_github_gogo_protobuf//proto:go_default_library", "@com_github_prysmaticlabs_ethereumapis//eth/v1alpha1:go_default_library", + "@com_github_sirupsen_logrus//:go_default_library", "@com_github_urfave_cli_v2//:go_default_library", "@in_gopkg_d4l3k_messagediff_v1//:go_default_library", ], diff --git a/slasher/db/kv/attester_slashings_test.go b/slasher/db/kv/attester_slashings_test.go index 9d6c210b1d69..f2d2eaf6b562 100644 --- a/slasher/db/kv/attester_slashings_test.go +++ b/slasher/db/kv/attester_slashings_test.go @@ -3,12 +3,12 @@ package kv import ( "context" "flag" - "reflect" "sort" "testing" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" "github.com/prysmaticlabs/prysm/shared/bytesutil" + "github.com/prysmaticlabs/prysm/shared/testutil/require" "github.com/prysmaticlabs/prysm/slasher/db/types" "github.com/urfave/cli/v2" ) @@ -21,20 +21,13 @@ func TestStore_AttesterSlashingNilBucket(t *testing.T) { as := ðpb.AttesterSlashing{Attestation_1: ðpb.IndexedAttestation{Signature: bytesutil.PadTo([]byte("hello"), 96)}} has, _, err := db.HasAttesterSlashing(ctx, as) - if err != nil { - t.Fatalf("HasAttesterSlashing should not return error: %v", err) - } - if has { - t.Fatal("HasAttesterSlashing should return false") - } + require.NoError(t, err, "HasAttesterSlashing should not return error") + require.Equal(t, false, has) p, err := db.AttesterSlashings(ctx, types.SlashingStatus(types.Active)) - if err != nil { - t.Fatalf("Failed to get attester slashing: %v", err) - } - if p == nil || len(p) != 0 { - t.Fatalf("Get should return empty attester slashing array for a non existent key") - } + require.NoError(t, err, "Failed to get attester slashing") + require.NotNil(t, p, "Get should return empty attester slashing array for a non existent key") + require.Equal(t, 0, len(p), "Get should return empty attester slashing array for a non existent key") } func TestStore_SaveAttesterSlashing(t *testing.T) { @@ -68,20 +61,13 @@ func TestStore_SaveAttesterSlashing(t *testing.T) { for _, tt := range tests { err := db.SaveAttesterSlashing(ctx, tt.ss, tt.as) - if err != nil { - t.Fatalf("save attester slashing failed: %v", err) - } + require.NoError(t, err, "Save attester slashing failed") attesterSlashings, err := db.AttesterSlashings(ctx, tt.ss) - if err != nil { - t.Fatalf("failed to get attester slashings: %v", err) - } - - if attesterSlashings == nil || !reflect.DeepEqual(attesterSlashings[0], tt.as) { - t.Fatalf("attester slashing: %v should be part of attester slashings response: %v", tt.as, attesterSlashings) - } + require.NoError(t, err, "Failed to get attester slashings") + require.NotNil(t, attesterSlashings) + require.DeepEqual(t, tt.as, attesterSlashings[0], "Slashing: %v should be part of slashings response: %v", tt.as, attesterSlashings) } - } func TestStore_SaveAttesterSlashings(t *testing.T) { @@ -99,19 +85,14 @@ func TestStore_SaveAttesterSlashings(t *testing.T) { {Attestation_1: ðpb.IndexedAttestation{Signature: bytesutil.PadTo([]byte("3"), 96), Data: data}, Attestation_2: att}, } err := db.SaveAttesterSlashings(ctx, types.Active, as) - if err != nil { - t.Fatalf("save attester slashing failed: %v", err) - } + require.NoError(t, err, "Save attester slashing failed") attesterSlashings, err := db.AttesterSlashings(ctx, types.Active) - if err != nil { - t.Fatalf("failed to get attester slashings: %v", err) - } + require.NoError(t, err, "Failed to get attester slashings") sort.SliceStable(attesterSlashings, func(i, j int) bool { return attesterSlashings[i].Attestation_1.Signature[0] < attesterSlashings[j].Attestation_1.Signature[0] }) - if attesterSlashings == nil || !reflect.DeepEqual(attesterSlashings, as) { - t.Fatalf("Attester slashing: %v should be part of attester slashings response: %v", as, attesterSlashings) - } + require.NotNil(t, attesterSlashings) + require.DeepEqual(t, as, attesterSlashings, "Slashing: %v should be part of slashings response: %v", as, attesterSlashings) } func TestStore_UpdateAttesterSlashingStatus(t *testing.T) { @@ -140,34 +121,20 @@ func TestStore_UpdateAttesterSlashingStatus(t *testing.T) { for _, tt := range tests { err := db.SaveAttesterSlashing(ctx, tt.ss, tt.as) - if err != nil { - t.Fatalf("save attester slashing failed: %v", err) - } + require.NoError(t, err, "Save attester slashing failed") } for _, tt := range tests { has, st, err := db.HasAttesterSlashing(ctx, tt.as) - if err != nil { - t.Fatalf("Failed to get attester slashing: %v", err) - } - if !has { - t.Fatalf("Failed to find attester slashing: %v", tt.as) - } - if st != tt.ss { - t.Fatalf("Failed to find attester slashing with the correct status: %v", tt.as) - } + require.NoError(t, err, "Failed to get attester slashing") + require.Equal(t, true, has, "Failed to find attester slashing: %v", tt.as) + require.Equal(t, tt.ss, st, "Failed to find attester slashing with the correct status: %v", tt.as) err = db.SaveAttesterSlashing(ctx, types.SlashingStatus(types.Included), tt.as) has, st, err = db.HasAttesterSlashing(ctx, tt.as) - if err != nil { - t.Fatalf("Failed to get attester slashing: %v", err) - } - if !has { - t.Fatalf("Failed to find attester slashing: %v", tt.as) - } - if st != types.Included { - t.Fatalf("Failed to find attester slashing with the correct status: %v", tt.as) - } + require.NoError(t, err, "Failed to get attester slashing") + require.Equal(t, true, has, "Failed to find attester slashing: %v", tt.as) + require.Equal(t, (types.SlashingStatus)(types.Included), st, "Failed to find attester slashing with the correct status: %v", tt.as) } } @@ -178,22 +145,12 @@ func TestStore_LatestEpochDetected(t *testing.T) { ctx := context.Background() e, err := db.GetLatestEpochDetected(ctx) - if err != nil { - t.Fatalf("Get latest epoch detected failed: %v", err) - } - if e != 0 { - t.Fatalf("Latest epoch detected should have been 0 before setting got: %d", e) - } + require.NoError(t, err, "Get latest epoch detected failed") + require.Equal(t, uint64(0), e, "Latest epoch detected should have been 0 before setting got: %d", e) epoch := uint64(1) err = db.SetLatestEpochDetected(ctx, epoch) - if err != nil { - t.Fatalf("Set latest epoch detected failed: %v", err) - } + require.NoError(t, err, "Set latest epoch detected failed") e, err = db.GetLatestEpochDetected(ctx) - if err != nil { - t.Fatalf("Get latest epoch detected failed: %v", err) - } - if e != epoch { - t.Fatalf("Latest epoch detected should have been: %d got: %d", epoch, e) - } + require.NoError(t, err, "Get latest epoch detected failed") + require.Equal(t, epoch, e, "Latest epoch detected should have been: %d got: %d", epoch, e) } diff --git a/slasher/db/kv/benchmark_test.go b/slasher/db/kv/benchmark_test.go index 18f85dc4127e..9abf6c464f7f 100644 --- a/slasher/db/kv/benchmark_test.go +++ b/slasher/db/kv/benchmark_test.go @@ -5,6 +5,8 @@ import ( "flag" "testing" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" + "github.com/prysmaticlabs/prysm/shared/testutil/require" "github.com/prysmaticlabs/prysm/slasher/detection/attestations/types" "github.com/urfave/cli/v2" ) @@ -22,22 +24,16 @@ func BenchmarkStore_SaveEpochSpans(b *testing.B) { es := &types.EpochStore{} es, err := es.SetValidatorSpan(benchmarkValidator, types.Span{MinSpan: 1, MaxSpan: 2, SigBytes: sigBytes, HasAttested: true}) - if err != nil { - b.Error(err) - } + assert.NoError(b, err) for i := 0; i < benchmarkValidator; i++ { es, err = es.SetValidatorSpan(uint64(i), types.Span{MinSpan: 1, MaxSpan: 2, SigBytes: sigBytes, HasAttested: true}) - if err != nil { - b.Error(err) - } + assert.NoError(b, err) } b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { err := db.SaveEpochSpans(ctx, uint64(i%54000), es, false) - if err != nil { - b.Fatalf("Save validator span map failed: %v", err) - } + require.NoError(b, err, "Save validator span map failed") } } @@ -49,30 +45,22 @@ func BenchmarkStore_EpochSpans(b *testing.B) { sigBytes := [2]byte{} es := &types.EpochStore{} es, err := es.SetValidatorSpan(benchmarkValidator, types.Span{MinSpan: 1, MaxSpan: 2, SigBytes: sigBytes, HasAttested: true}) - if err != nil { - b.Error(err) - } + assert.NoError(b, err) for i := 0; i < benchmarkValidator; i++ { es, err = es.SetValidatorSpan(uint64(i), types.Span{MinSpan: 1, MaxSpan: 2, SigBytes: sigBytes, HasAttested: true}) - if err != nil { - b.Error(err) - } + assert.NoError(b, err) } b.Log(len(es.Bytes())) for i := 0; i < 200; i++ { err := db.SaveEpochSpans(ctx, uint64(i), es, false) - if err != nil { - b.Fatalf("Save validator span map failed: %v", err) - } + require.NoError(b, err, "Save validator span map failed") } b.Log(db.db.Info()) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { _, err := db.EpochSpans(ctx, uint64(i%200), false) - if err != nil { - b.Fatalf("Read validator span map failed: %v", err) - } + require.NoError(b, err, "Read validator span map failed") } } @@ -80,14 +68,10 @@ func BenchmarkStore_GetValidatorSpan(b *testing.B) { sigBytes := [2]byte{} es := &types.EpochStore{} es, err := es.SetValidatorSpan(benchmarkValidator, types.Span{MinSpan: 1, MaxSpan: 2, SigBytes: sigBytes, HasAttested: true}) - if err != nil { - b.Error(err) - } + assert.NoError(b, err) for i := 0; i < benchmarkValidator; i++ { es, err = es.SetValidatorSpan(uint64(i), types.Span{MinSpan: uint16(i), MaxSpan: uint16(benchmarkValidator - i), SigBytes: sigBytes, HasAttested: true}) - if err != nil { - b.Error(err) - } + assert.NoError(b, err) } b.Log(len(es.Bytes())) @@ -95,9 +79,7 @@ func BenchmarkStore_GetValidatorSpan(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { _, err := es.GetValidatorSpan(uint64(i % benchmarkValidator)) - if err != nil { - b.Fatalf("Read validator span map failed: %v", err) - } + require.NoError(b, err, "Read validator span map failed") } } @@ -106,23 +88,17 @@ func BenchmarkStore_SetValidatorSpan(b *testing.B) { var err error es := &types.EpochStore{} es, err = es.SetValidatorSpan(benchmarkValidator, types.Span{MinSpan: 1, MaxSpan: 2, SigBytes: sigBytes, HasAttested: true}) - if err != nil { - b.Error(err) - } + assert.NoError(b, err) for i := 0; i < benchmarkValidator; i++ { es, err = es.SetValidatorSpan(uint64(i), types.Span{MinSpan: uint16(i), MaxSpan: uint16(benchmarkValidator - i), SigBytes: sigBytes, HasAttested: true}) - if err != nil { - b.Error(err) - } + assert.NoError(b, err) } b.Log(len(es.Bytes())) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { es, err = es.SetValidatorSpan(uint64(i%benchmarkValidator), types.Span{MinSpan: uint16(i), MaxSpan: uint16(benchmarkValidator - i), SigBytes: sigBytes, HasAttested: true}) - if err != nil { - b.Fatalf("Read validator span map failed: %v", err) - } + require.NoError(b, err, "Read validator span map failed") } } diff --git a/slasher/db/kv/block_header_test.go b/slasher/db/kv/block_header_test.go index ad2faced7da0..f80ce4642a77 100644 --- a/slasher/db/kv/block_header_test.go +++ b/slasher/db/kv/block_header_test.go @@ -3,12 +3,13 @@ package kv import ( "context" "flag" - "reflect" "testing" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" "github.com/prysmaticlabs/prysm/beacon-chain/core/helpers" "github.com/prysmaticlabs/prysm/shared/params" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" + "github.com/prysmaticlabs/prysm/shared/testutil/require" "github.com/urfave/cli/v2" ) @@ -21,18 +22,11 @@ func TestNilDBHistoryBlkHdr(t *testing.T) { slot := uint64(1) validatorID := uint64(1) - hasBlockHeader := db.HasBlockHeader(ctx, slot, validatorID) - if hasBlockHeader { - t.Fatal("HasBlockHeader should return false") - } + require.Equal(t, false, db.HasBlockHeader(ctx, slot, validatorID)) bPrime, err := db.BlockHeaders(ctx, slot, validatorID) - if err != nil { - t.Fatalf("failed to get block: %v", err) - } - if bPrime != nil { - t.Fatalf("get should return nil for a non existent key") - } + require.NoError(t, err, "Failed to get block") + require.DeepEqual(t, ([]*ethpb.SignedBeaconBlockHeader)(nil), bPrime, "Should return nil for a non existent key") } func TestSaveHistoryBlkHdr(t *testing.T) { @@ -60,20 +54,13 @@ func TestSaveHistoryBlkHdr(t *testing.T) { for _, tt := range tests { err := db.SaveBlockHeader(ctx, tt.bh) - if err != nil { - t.Fatalf("save block failed: %v", err) - } + require.NoError(t, err, "Save block failed") bha, err := db.BlockHeaders(ctx, tt.bh.Header.Slot, tt.bh.Header.ProposerIndex) - if err != nil { - t.Fatalf("failed to get block: %v", err) - } - - if bha == nil || !reflect.DeepEqual(bha[0], tt.bh) { - t.Fatalf("get should return bh: %v", bha) - } + require.NoError(t, err, "Failed to get block") + require.NotNil(t, bha) + require.DeepEqual(t, tt.bh, bha[0], "Should return bh") } - } func TestDeleteHistoryBlkHdr(t *testing.T) { @@ -96,37 +83,22 @@ func TestDeleteHistoryBlkHdr(t *testing.T) { }, } for _, tt := range tests { - err := db.SaveBlockHeader(ctx, tt.bh) - if err != nil { - t.Fatalf("save block failed: %v", err) - } + require.NoError(t, err, "Save block failed") } for _, tt := range tests { bha, err := db.BlockHeaders(ctx, tt.bh.Header.Slot, tt.bh.Header.ProposerIndex) - if err != nil { - t.Fatalf("failed to get block: %v", err) - } + require.NoError(t, err, "Failed to get block") + require.NotNil(t, bha) + require.DeepEqual(t, tt.bh, bha[0], "Should return bh") - if bha == nil || !reflect.DeepEqual(bha[0], tt.bh) { - t.Fatalf("get should return bh: %v", bha) - } err = db.DeleteBlockHeader(ctx, tt.bh) - if err != nil { - t.Fatalf("save block failed: %v", err) - } + require.NoError(t, err, "Save block failed") bh, err := db.BlockHeaders(ctx, tt.bh.Header.Slot, tt.bh.Header.ProposerIndex) - - if err != nil { - t.Fatal(err) - } - if bh != nil { - t.Errorf("Expected block to have been deleted, received: %v", bh) - } - + require.NoError(t, err) + assert.DeepEqual(t, ([]*ethpb.SignedBeaconBlockHeader)(nil), bh, "Expected block to have been deleted") } - } func TestHasHistoryBlkHdr(t *testing.T) { @@ -152,27 +124,17 @@ func TestHasHistoryBlkHdr(t *testing.T) { }, } for _, tt := range tests { - found := db.HasBlockHeader(ctx, tt.bh.Header.Slot, tt.bh.Header.ProposerIndex) - if found { - t.Fatal("has block header should return false for block headers that are not in db") - } + require.Equal(t, false, found, "Has block header should return false for block headers that are not in db") err := db.SaveBlockHeader(ctx, tt.bh) - if err != nil { - t.Fatalf("save block failed: %v", err) - } + require.NoError(t, err, "Save block failed") } for _, tt := range tests { err := db.SaveBlockHeader(ctx, tt.bh) - if err != nil { - t.Fatalf("save block failed: %v", err) - } + require.NoError(t, err, "Save block failed") found := db.HasBlockHeader(ctx, tt.bh.Header.Slot, tt.bh.Header.ProposerIndex) - - if !found { - t.Fatal("has block header should return true") - } + require.Equal(t, true, found, "Block header should exist") } } @@ -204,39 +166,26 @@ func TestPruneHistoryBlkHdr(t *testing.T) { for _, tt := range tests { err := db.SaveBlockHeader(ctx, tt.bh) - if err != nil { - t.Fatalf("save block header failed: %v", err) - } + require.NoError(t, err, "Save block header failed") bha, err := db.BlockHeaders(ctx, tt.bh.Header.Slot, tt.bh.Header.ProposerIndex) - if err != nil { - t.Fatalf("failed to get block header: %v", err) - } - - if bha == nil || !reflect.DeepEqual(bha[0], tt.bh) { - t.Fatalf("get should return bh: %v", bha) - } + require.NoError(t, err, "Failed to get block header") + require.NotNil(t, bha) + require.DeepEqual(t, tt.bh, bha[0], "Should return bh") } currentEpoch := uint64(3) historyToKeep := uint64(2) err := db.PruneBlockHistory(ctx, currentEpoch, historyToKeep) - if err != nil { - t.Fatalf("failed to prune: %v", err) - } + require.NoError(t, err, "Failed to prune") for _, tt := range tests { bha, err := db.BlockHeaders(ctx, tt.bh.Header.Slot, tt.bh.Header.ProposerIndex) - if err != nil { - t.Fatalf("failed to get block header: %v", err) - } + require.NoError(t, err, "Failed to get block header") if helpers.SlotToEpoch(tt.bh.Header.Slot) >= currentEpoch-historyToKeep { - if bha == nil || !reflect.DeepEqual(bha[0], tt.bh) { - t.Fatalf("get should return bh: %v", bha) - } + require.NotNil(t, bha) + require.DeepEqual(t, tt.bh, bha[0], "Should return bh") } else { - if bha != nil { - t.Fatalf("block header should have been pruned: %v", bha) - } + require.NotNil(t, bha, "Block header should have been pruned") } } } diff --git a/slasher/db/kv/chain_data_test.go b/slasher/db/kv/chain_data_test.go index 62f63b876dee..b5549f6e3eb1 100644 --- a/slasher/db/kv/chain_data_test.go +++ b/slasher/db/kv/chain_data_test.go @@ -5,8 +5,9 @@ import ( "flag" "testing" - "github.com/gogo/protobuf/proto" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" + "github.com/prysmaticlabs/prysm/shared/testutil/require" "github.com/urfave/cli/v2" ) @@ -42,15 +43,10 @@ func TestChainHead(t *testing.T) { } for _, tt := range tests { - if err := db.SaveChainHead(ctx, tt.head); err != nil { - t.Fatal(err) - } + require.NoError(t, db.SaveChainHead(ctx, tt.head)) head, err := db.ChainHead(ctx) - if err != nil { - t.Fatalf("failed to get block: %v", err) - } - if head == nil || !proto.Equal(head, tt.head) { - t.Errorf("Expected %v, got %v", tt.head, head) - } + require.NoError(t, err, "Failed to get block") + assert.NotNil(t, head) + assert.DeepEqual(t, tt.head, head) } } diff --git a/slasher/db/kv/indexed_attestations_test.go b/slasher/db/kv/indexed_attestations_test.go index 0dc70b0838ee..a50528103c76 100644 --- a/slasher/db/kv/indexed_attestations_test.go +++ b/slasher/db/kv/indexed_attestations_test.go @@ -3,10 +3,10 @@ package kv import ( "context" "flag" - "reflect" "testing" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" + "github.com/prysmaticlabs/prysm/shared/testutil/require" "github.com/urfave/cli/v2" ) @@ -68,12 +68,8 @@ func TestHasIndexedAttestation_NilDB(t *testing.T) { ctx := context.Background() hasIdxAtt, err := db.HasIndexedAttestation(ctx, tests[0].idxAtt) - if err != nil { - t.Fatal(err) - } - if hasIdxAtt { - t.Fatal("HasIndexedAttestation should return false") - } + require.NoError(t, err) + require.Equal(t, false, hasIdxAtt) } func TestSaveIndexedAttestation(t *testing.T) { @@ -83,18 +79,11 @@ func TestSaveIndexedAttestation(t *testing.T) { ctx := context.Background() for _, tt := range tests { - if err := db.SaveIndexedAttestation(ctx, tt.idxAtt); err != nil { - t.Fatalf("save indexed attestation failed: %v", err) - } + require.NoError(t, db.SaveIndexedAttestation(ctx, tt.idxAtt), "Save indexed attestation failed") exists, err := db.HasIndexedAttestation(ctx, tt.idxAtt) - if err != nil { - t.Fatalf("failed to get indexed attestation: %v", err) - } - - if !exists { - t.Fatal("Expected to find saved attestation in DB") - } + require.NoError(t, err, "Failed to get indexed attestation") + require.Equal(t, true, exists, "Expected to find saved attestation in DB") } } @@ -327,27 +316,16 @@ func TestIndexedAttestationsWithPrefix(t *testing.T) { db := setupDB(t, cli.NewContext(&app, set, nil)) ctx := context.Background() - if err := db.SaveIndexedAttestations(ctx, tt.attsInDB); err != nil { - t.Fatalf("save indexed attestation failed: %v", err) - } + require.NoError(t, db.SaveIndexedAttestations(ctx, tt.attsInDB), "Save indexed attestation failed") for _, att := range tt.attsInDB { found, err := db.HasIndexedAttestation(ctx, att) - if err != nil { - t.Fatal(err) - } - if !found { - t.Fatalf("Expected to save %v", att) - } + require.NoError(t, err) + require.Equal(t, true, found, "Expected to save %v", att) } idxAtts, err := db.IndexedAttestationsWithPrefix(ctx, tt.targetEpoch, tt.searchPrefix) - if err != nil { - t.Fatalf("failed to get indexed attestation: %v", err) - } - - if !reflect.DeepEqual(tt.expectedResult, idxAtts) { - t.Fatalf("Expected %v, received: %v", tt.expectedResult, idxAtts) - } + require.NoError(t, err, "Failed to get indexed attestation") + require.DeepEqual(t, tt.expectedResult, idxAtts) }) } } @@ -496,27 +474,16 @@ func TestIndexedAttestationsForTarget(t *testing.T) { db := setupDB(t, cli.NewContext(&app, set, nil)) ctx := context.Background() - if err := db.SaveIndexedAttestations(ctx, tt.attsInDB); err != nil { - t.Fatalf("save indexed attestation failed: %v", err) - } + require.NoError(t, db.SaveIndexedAttestations(ctx, tt.attsInDB), "Save indexed attestation failed") for _, att := range tt.attsInDB { found, err := db.HasIndexedAttestation(ctx, att) - if err != nil { - t.Fatal(err) - } - if !found { - t.Fatalf("Expected to save %v", att) - } + require.NoError(t, err) + require.Equal(t, true, found, "Expected to save %v", att) } idxAtts, err := db.IndexedAttestationsForTarget(ctx, tt.targetEpoch) - if err != nil { - t.Fatalf("failed to get indexed attestation: %v", err) - } - - if !reflect.DeepEqual(tt.expectedResult, idxAtts) { - t.Fatalf("Expected %v, received: %v", tt.expectedResult, idxAtts) - } + require.NoError(t, err, "Failed to get indexed attestation: %v", err) + require.DeepEqual(t, tt.expectedResult, idxAtts) }) } } @@ -689,34 +656,22 @@ func TestDeleteIndexedAttestation(t *testing.T) { db := setupDB(t, cli.NewContext(app, set, nil)) ctx := context.Background() - if err := db.SaveIndexedAttestations(ctx, tt.attsInDB); err != nil { - t.Fatalf("save indexed attestation failed: %v", err) - } + require.NoError(t, db.SaveIndexedAttestations(ctx, tt.attsInDB), "Save indexed attestation failed") for _, att := range tt.attsInDB { found, err := db.HasIndexedAttestation(ctx, att) - if err != nil { - t.Fatal(err) - } - if !found { - t.Fatalf("Expected to save %v", att) - } + require.NoError(t, err) + require.Equal(t, true, found, "Expected to save %v", att) } for _, att := range tt.deleteAtts { - if err := db.DeleteIndexedAttestation(ctx, att); err != nil { - t.Fatal(err) - } + require.NoError(t, db.DeleteIndexedAttestation(ctx, att)) } for i, att := range tt.attsInDB { found, err := db.HasIndexedAttestation(ctx, att) - if err != nil { - t.Fatal(err) - } - if found != tt.foundArray[i] { - t.Fatalf("Expected found to be %t: %v", tt.foundArray[i], att) - } + require.NoError(t, err) + require.Equal(t, tt.foundArray[i], found) } }) } @@ -730,26 +685,16 @@ func TestHasIndexedAttestation(t *testing.T) { for _, tt := range tests { exists, err := db.HasIndexedAttestation(ctx, tt.idxAtt) - if err != nil { - t.Fatal(err) - } - if exists { - t.Fatal("has indexed attestation should return false for indexed attestations that are not in db") - } + require.NoError(t, err) + require.Equal(t, false, exists, "Has indexed attestation should return false for indexed attestations that are not in db") - if err := db.SaveIndexedAttestation(ctx, tt.idxAtt); err != nil { - t.Fatalf("save indexed attestation failed: %v", err) - } + require.NoError(t, db.SaveIndexedAttestation(ctx, tt.idxAtt), "Save indexed attestation failed") } for _, tt := range tests { exists, err := db.HasIndexedAttestation(ctx, tt.idxAtt) - if err != nil { - t.Fatal(err) - } - if !exists { - t.Fatal("has indexed attestation should return true") - } + require.NoError(t, err) + require.Equal(t, true, exists) } } @@ -760,39 +705,24 @@ func TestPruneHistoryIndexedAttestation(t *testing.T) { ctx := context.Background() for _, tt := range tests { - if err := db.SaveIndexedAttestation(ctx, tt.idxAtt); err != nil { - t.Fatalf("save indexed attestation failed: %v", err) - } + require.NoError(t, db.SaveIndexedAttestation(ctx, tt.idxAtt), "Save indexed attestation failed") found, err := db.HasIndexedAttestation(ctx, tt.idxAtt) - if err != nil { - t.Fatalf("failed to get indexed attestation: %v", err) - } - - if !found { - t.Fatal("Expected to find attestation in DB") - } + require.NoError(t, err, "Failed to get indexed attestation") + require.Equal(t, true, found, "Expected to find attestation in DB") } currentEpoch := uint64(2) historyToKeep := uint64(1) - if err := db.PruneAttHistory(ctx, currentEpoch, historyToKeep); err != nil { - t.Fatalf("failed to prune: %v", err) - } + require.NoError(t, db.PruneAttHistory(ctx, currentEpoch, historyToKeep), "Failed to prune") for _, tt := range tests { exists, err := db.HasIndexedAttestation(ctx, tt.idxAtt) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if tt.idxAtt.Data.Target.Epoch > currentEpoch-historyToKeep { - if !exists { - t.Fatal("Expected to find attestation newer than prune age in DB") - } + require.Equal(t, true, exists, "Expected to find attestation newer than prune age in DB") } else { - if exists { - t.Fatal("Expected to not find attestation older than prune age in DB") - } + require.Equal(t, false, exists, "Expected to not find attestation older than prune age in DB") } } } diff --git a/slasher/db/kv/kv_test.go b/slasher/db/kv/kv_test.go index bb8098d6ce7e..6058cfc3d65f 100644 --- a/slasher/db/kv/kv_test.go +++ b/slasher/db/kv/kv_test.go @@ -3,61 +3,51 @@ package kv import ( "crypto/rand" "fmt" + "io/ioutil" "math/big" "os" "path" "testing" "github.com/prysmaticlabs/prysm/shared/testutil" + "github.com/prysmaticlabs/prysm/shared/testutil/require" + "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" ) +func TestMain(m *testing.M) { + logrus.SetLevel(logrus.DebugLevel) + logrus.SetOutput(ioutil.Discard) + + os.Exit(m.Run()) +} + func setupDB(t testing.TB, ctx *cli.Context) *Store { randPath, err := rand.Int(rand.Reader, big.NewInt(1000000)) - if err != nil { - t.Fatalf("Could not generate random file path: %v", err) - } + require.NoError(t, err, "Could not generate random file path") p := path.Join(testutil.TempDir(), fmt.Sprintf("/%d", randPath)) - if err := os.RemoveAll(p); err != nil { - t.Fatalf("Failed to remove directory: %v", err) - } + require.NoError(t, os.RemoveAll(p), "Failed to remove directory") cfg := &Config{} db, err := NewKVStore(p, cfg) - if err != nil { - t.Fatalf("Failed to instantiate DB: %v", err) - } + require.NoError(t, err, "Failed to instantiate DB") t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database: %v", err) - } - if err := os.RemoveAll(db.DatabasePath()); err != nil { - t.Fatalf("Failed to remove directory: %v", err) - } + require.NoError(t, db.Close(), "Failed to close database") + require.NoError(t, os.RemoveAll(db.DatabasePath()), "Failed to remove directory") }) return db } func setupDBDiffCacheSize(t testing.TB, cacheSize int) *Store { randPath, err := rand.Int(rand.Reader, big.NewInt(1000000)) - if err != nil { - t.Fatalf("Could not generate random file path: %v", err) - } + require.NoError(t, err, "Could not generate random file path") p := path.Join(testutil.TempDir(), fmt.Sprintf("/%d", randPath)) - if err := os.RemoveAll(p); err != nil { - t.Fatalf("Failed to remove directory: %v", err) - } + require.NoError(t, os.RemoveAll(p), "Failed to remove directory") cfg := &Config{SpanCacheSize: cacheSize} db, err := NewKVStore(p, cfg) - if err != nil { - t.Fatalf("Failed to instantiate DB: %v", err) - } + require.NoError(t, err, "Failed to instantiate DB") t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database: %v", err) - } - if err := os.RemoveAll(db.DatabasePath()); err != nil { - t.Fatalf("Failed to remove directory: %v", err) - } + require.NoError(t, db.Close(), "Failed to close database") + require.NoError(t, os.RemoveAll(db.DatabasePath()), "Failed to remove directory") }) return db } diff --git a/slasher/db/kv/proposer_slashings_test.go b/slasher/db/kv/proposer_slashings_test.go index cdf0d7690576..8a51d04710e1 100644 --- a/slasher/db/kv/proposer_slashings_test.go +++ b/slasher/db/kv/proposer_slashings_test.go @@ -8,6 +8,7 @@ import ( "testing" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" + "github.com/prysmaticlabs/prysm/shared/testutil/require" "github.com/prysmaticlabs/prysm/slasher/db/types" "github.com/urfave/cli/v2" "gopkg.in/d4l3k/messagediff.v1" @@ -21,20 +22,13 @@ func TestStore_ProposerSlashingNilBucket(t *testing.T) { ps := ðpb.ProposerSlashing{Header_1: ðpb.SignedBeaconBlockHeader{Header: ðpb.BeaconBlockHeader{ProposerIndex: 1}}} has, _, err := db.HasProposerSlashing(ctx, ps) - if err != nil { - t.Fatalf("HasProposerSlashing should not return error: %v", err) - } - if has { - t.Fatal("HasProposerSlashing should return false") - } + require.NoError(t, err) + require.Equal(t, false, has) p, err := db.ProposalSlashingsByStatus(ctx, types.SlashingStatus(types.Active)) - if err != nil { - t.Fatalf("Failed to get proposer slashing: %v", err) - } - if p == nil || len(p) != 0 { - t.Fatalf("Get should return empty attester slashing array for a non existent key") - } + require.NoError(t, err, "Failed to get proposer slashing") + require.NotNil(t, p) + require.Equal(t, 0, len(p), "Get should return empty attester slashing array for a non existent key") } func TestStore_SaveProposerSlashing(t *testing.T) { @@ -72,14 +66,10 @@ func TestStore_SaveProposerSlashing(t *testing.T) { for _, tt := range tests { err := db.SaveProposerSlashing(ctx, tt.ss, tt.ps) - if err != nil { - t.Fatalf("Save proposer slashing failed: %v", err) - } + require.NoError(t, err, "Save proposer slashing failed") proposerSlashings, err := db.ProposalSlashingsByStatus(ctx, tt.ss) - if err != nil { - t.Fatalf("Failed to get proposer slashings: %v", err) - } + require.NoError(t, err, "Failed to get proposer slashings") if proposerSlashings == nil || !reflect.DeepEqual(proposerSlashings[0], tt.ps) { diff, _ := messagediff.PrettyDiff(proposerSlashings[0], tt.ps) @@ -87,7 +77,6 @@ func TestStore_SaveProposerSlashing(t *testing.T) { t.Fatalf("Proposer slashing: %v should be part of proposer slashings response: %v", tt.ps, proposerSlashings) } } - } func TestStore_UpdateProposerSlashingStatus(t *testing.T) { @@ -116,37 +105,21 @@ func TestStore_UpdateProposerSlashingStatus(t *testing.T) { for _, tt := range tests { err := db.SaveProposerSlashing(ctx, tt.ss, tt.ps) - if err != nil { - t.Fatalf("Save proposer slashing failed: %v", err) - } + require.NoError(t, err, "Save proposer slashing failed") } for _, tt := range tests { has, st, err := db.HasProposerSlashing(ctx, tt.ps) - if err != nil { - t.Fatalf("Failed to get proposer slashing: %v", err) - } - if !has { - t.Fatalf("Failed to find proposer slashing: %v", tt.ps) - } - if st != tt.ss { - t.Fatalf("Failed to find proposer slashing with the correct status: %v", tt.ps) - } + require.NoError(t, err, "Failed to get proposer slashing") + require.Equal(t, true, has, "Failed to find proposer slashing") + require.Equal(t, tt.ss, st, "Failed to find proposer slashing with the correct status") err = db.SaveProposerSlashing(ctx, types.SlashingStatus(types.Included), tt.ps) has, st, err = db.HasProposerSlashing(ctx, tt.ps) - if err != nil { - t.Fatalf("Failed to get proposer slashing: %v", err) - } - if !has { - t.Fatalf("Failed to find proposer slashing: %v", tt.ps) - } - if st != types.Included { - t.Fatalf("Failed to find proposer slashing with the correct status: %v", tt.ps) - } - + require.NoError(t, err, "Failed to get proposer slashing") + require.Equal(t, true, has, "Failed to find proposer slashing") + require.Equal(t, (types.SlashingStatus)(types.Included), st, "Failed to find proposer slashing with the correct status") } - } func TestStore_SaveProposerSlashings(t *testing.T) { @@ -170,13 +143,9 @@ func TestStore_SaveProposerSlashings(t *testing.T) { }, } err := db.SaveProposerSlashings(ctx, types.Active, ps) - if err != nil { - t.Fatalf("Save proposer slashings failed: %v", err) - } + require.NoError(t, err, "Save proposer slashings failed") proposerSlashings, err := db.ProposalSlashingsByStatus(ctx, types.Active) - if err != nil { - t.Fatalf("Failed to get proposer slashings: %v", err) - } + require.NoError(t, err, "Failed to get proposer slashings") sort.SliceStable(proposerSlashings, func(i, j int) bool { return proposerSlashings[i].Header_1.Header.ProposerIndex < proposerSlashings[j].Header_1.Header.ProposerIndex }) diff --git a/slasher/db/kv/spanner_new_test.go b/slasher/db/kv/spanner_new_test.go index f003e5a32f8d..267dacbad184 100644 --- a/slasher/db/kv/spanner_new_test.go +++ b/slasher/db/kv/spanner_new_test.go @@ -1,13 +1,12 @@ package kv import ( - "bytes" "context" "encoding/hex" "flag" - "reflect" "testing" + "github.com/prysmaticlabs/prysm/shared/testutil/require" dbTypes "github.com/prysmaticlabs/prysm/slasher/db/types" "github.com/prysmaticlabs/prysm/slasher/detection/attestations/types" "github.com/urfave/cli/v2" @@ -62,16 +61,10 @@ func TestValidatorSpans_NilDB(t *testing.T) { validatorIdx := uint64(1) es, err := db.EpochSpans(ctx, validatorIdx, false) - if err != nil { - t.Fatalf("Nil EpochSpansMap should not return error: %v", err) - } + require.NoError(t, err, "Nil EpochSpansMap should not return error") cleanStore, err := types.NewEpochStore([]byte{}) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(es, cleanStore) { - t.Fatal("EpochSpans should return empty byte array if no record exists in the db") - } + require.NoError(t, err) + require.DeepEqual(t, es, cleanStore, "EpochSpans should return empty byte array if no record exists in the db") } func TestStore_SaveReadEpochSpans(t *testing.T) { @@ -83,39 +76,25 @@ func TestStore_SaveReadEpochSpans(t *testing.T) { for _, tt := range spanNewTests { t.Run(tt.name, func(t *testing.T) { spans, err := hex.DecodeString(tt.spansHex) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) es, err := types.NewEpochStore(spans) - if err != tt.err { - t.Fatalf("Failed to get the right error expected: %v got: %v", tt.err, err) - } - if err = db.SaveEpochSpans(ctx, tt.epoch, es, false); err != nil { - t.Fatal(err) + if tt.err != nil { + require.ErrorContains(t, tt.err.Error(), err) + } else { + require.NoError(t, err) } + require.NoError(t, db.SaveEpochSpans(ctx, tt.epoch, es, false)) sm, err := db.EpochSpans(ctx, tt.epoch, false) - if err != nil { - t.Fatalf("Failed to get validator spans: %v", err) - } + require.NoError(t, err, "Failed to get validator spans") spansResult, err := hex.DecodeString(tt.spansResultHex) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) esr, err := types.NewEpochStore(spansResult) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(sm, esr) { - t.Fatalf("Get should return validator spans: %v got: %v", spansResult, sm) - } + require.NoError(t, err) + require.DeepEqual(t, sm, esr, "Get should return validator spans: %v", spansResult) s, err := es.GetValidatorSpan(1) - if err != nil { - t.Fatalf("Failed to get validator 1 span: %v", err) - } - if !reflect.DeepEqual(s, tt.validator1Span) { - t.Fatalf("Get should return validator span for validator 2: %v got: %v", tt.validator1Span, s) - } + require.NoError(t, err, "Failed to get validator 1 span") + require.DeepEqual(t, tt.validator1Span, s, "Get should return validator span for validator 2: %v", tt.validator1Span) }) } } @@ -135,30 +114,18 @@ func TestStore_SaveEpochSpans_ToCache(t *testing.T) { 100: {MinSpan: 49, MaxSpan: 96, SigBytes: [2]byte{11, 98}, HasAttested: true}, } epochStore, err := types.EpochStoreFromMap(spansToSave) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) epoch := uint64(9) - if err := db.SaveEpochSpans(ctx, epoch, epochStore, dbTypes.UseCache); err != nil { - t.Fatal(err) - } + require.NoError(t, db.SaveEpochSpans(ctx, epoch, epochStore, dbTypes.UseCache)) esFromCache, err := db.EpochSpans(ctx, epoch, dbTypes.UseCache) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(epochStore.Bytes(), esFromCache.Bytes()) { - t.Fatalf("Expected store from DB to be %#x, received %#x", epochStore.Bytes(), esFromCache.Bytes()) - } + require.NoError(t, err) + require.DeepEqual(t, epochStore.Bytes(), esFromCache.Bytes()) esFromDB, err := db.EpochSpans(ctx, epoch, dbTypes.UseDB) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(esFromDB.Bytes(), esFromCache.Bytes()) { - t.Fatalf("Expected store asked from DB to use cache, \nreceived %#x, \nexpected %#x", esFromDB.Bytes(), esFromCache.Bytes()) - } + require.NoError(t, err) + require.DeepEqual(t, esFromDB.Bytes(), esFromCache.Bytes()) } func TestStore_SaveEpochSpans_ToDB(t *testing.T) { @@ -176,29 +143,17 @@ func TestStore_SaveEpochSpans_ToDB(t *testing.T) { 100: {MinSpan: 49, MaxSpan: 96, SigBytes: [2]byte{11, 98}, HasAttested: true}, } epochStore, err := types.EpochStoreFromMap(spansToSave) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) epoch := uint64(9) - if err := db.SaveEpochSpans(ctx, epoch, epochStore, dbTypes.UseDB); err != nil { - t.Fatal(err) - } + require.NoError(t, db.SaveEpochSpans(ctx, epoch, epochStore, dbTypes.UseDB)) // Expect cache to retrieve from DB if its not in cache. esFromCache, err := db.EpochSpans(ctx, epoch, dbTypes.UseCache) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(esFromCache.Bytes(), epochStore.Bytes()) { - t.Fatalf("Expected cache request to be %#x, expected %#x", epochStore.Bytes(), esFromCache.Bytes()) - } + require.NoError(t, err) + require.DeepEqual(t, esFromCache.Bytes(), epochStore.Bytes()) esFromDB, err := db.EpochSpans(ctx, epoch, dbTypes.UseDB) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(epochStore.Bytes(), esFromDB.Bytes()) { - t.Fatalf("Expected store from DB to be %#x, received %#x", epochStore.Bytes(), esFromDB.Bytes()) - } + require.NoError(t, err) + require.DeepEqual(t, epochStore.Bytes(), esFromDB.Bytes()) } diff --git a/slasher/db/kv/spanner_test.go b/slasher/db/kv/spanner_test.go index 5c3fbf761a17..bcb3bb9dd729 100644 --- a/slasher/db/kv/spanner_test.go +++ b/slasher/db/kv/spanner_test.go @@ -3,10 +3,11 @@ package kv import ( "context" "flag" - "reflect" "testing" "time" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" + "github.com/prysmaticlabs/prysm/shared/testutil/require" "github.com/prysmaticlabs/prysm/slasher/detection/attestations/types" "github.com/urfave/cli/v2" ) @@ -55,12 +56,8 @@ func TestValidatorSpanMap_NilDB(t *testing.T) { validatorIdx := uint64(1) vsm, _, err := db.EpochSpansMap(ctx, validatorIdx) - if err != nil { - t.Fatalf("Nil EpochSpansMap should not return error: %v", err) - } - if !reflect.DeepEqual(vsm, map[uint64]types.Span{}) { - t.Fatal("EpochSpansMap should return nil") - } + require.NoError(t, err, "Nil EpochSpansMap should not return error") + require.DeepEqual(t, map[uint64]types.Span{}, vsm, "EpochSpansMap should return empty map") } func TestStore_SaveSpans(t *testing.T) { @@ -71,24 +68,14 @@ func TestStore_SaveSpans(t *testing.T) { for _, tt := range spanTests { err := db.SaveEpochSpansMap(ctx, tt.epoch, tt.spanMap) - if err != nil { - t.Fatalf("Save validator span map failed: %v", err) - } + require.NoError(t, err, "Save validator span map failed") sm, _, err := db.EpochSpansMap(ctx, tt.epoch) - if err != nil { - t.Fatalf("Failed to get validator span map: %v", err) - } - - if sm == nil || !reflect.DeepEqual(sm, tt.spanMap) { - t.Fatalf("Get should return validator span map: %v got: %v", tt.spanMap, sm) - } + require.NoError(t, err, "Failed to get validator span map") + require.NotNil(t, sm) + require.DeepEqual(t, tt.spanMap, sm, "Get should return validator span map") s, err := db.EpochSpanByValidatorIndex(ctx, 1, tt.epoch) - if err != nil { - t.Fatalf("Failed to get validator span for epoch 1: %v", err) - } - if !reflect.DeepEqual(s, tt.spanMap[1]) { - t.Fatalf("Get should return validator spans for epoch 1: %v got: %v", tt.spanMap[1], s) - } + require.NoError(t, err, "Failed to get validator span for epoch 1") + require.DeepEqual(t, tt.spanMap[1], s, "Get should return validator spans for epoch 1") } } @@ -100,26 +87,17 @@ func TestStore_SaveCachedSpans(t *testing.T) { for _, tt := range spanTests { err := db.SaveEpochSpansMap(ctx, tt.epoch, tt.spanMap) - if err != nil { - t.Fatalf("Save validator span map failed: %v", err) - } + require.NoError(t, err, "Save validator span map failed") // wait for value to pass through cache buffers time.Sleep(time.Millisecond * 10) sm, _, err := db.EpochSpansMap(ctx, tt.epoch) - if err != nil { - t.Fatalf("Failed to get validator span map: %v", err) - } + require.NoError(t, err, "Failed to get validator span map") + require.NotNil(t, sm) + require.DeepEqual(t, tt.spanMap, sm, "Get should return validator span map") - if sm == nil || !reflect.DeepEqual(sm, tt.spanMap) { - t.Fatalf("Get should return validator span map: %v got: %v", tt.spanMap, sm) - } s, err := db.EpochSpanByValidatorIndex(ctx, 1, tt.epoch) - if err != nil { - t.Fatalf("Failed to get validator span for epoch 1: %v", err) - } - if !reflect.DeepEqual(s, tt.spanMap[1]) { - t.Fatalf("Get should return validator spans for epoch 1: %v got: %v", tt.spanMap[1], s) - } + require.NoError(t, err, "Failed to get validator span for epoch 1") + require.DeepEqual(t, tt.spanMap[1], s, "Get should return validator spans for epoch 1") } } @@ -131,30 +109,19 @@ func TestStore_DeleteEpochSpans(t *testing.T) { db.spanCacheEnabled = false for _, tt := range spanTests { err := db.SaveEpochSpansMap(ctx, tt.epoch, tt.spanMap) - if err != nil { - t.Fatalf("Save validator span map failed: %v", err) - } + require.NoError(t, err, "Save validator span map failed") } for _, tt := range spanTests { sm, _, err := db.EpochSpansMap(ctx, tt.epoch) - if err != nil { - t.Fatalf("Failed to get validator span map: %v", err) - } - if sm == nil || !reflect.DeepEqual(sm, tt.spanMap) { - t.Fatalf("Get should return validator span map: %v got: %v", tt.spanMap, sm) - } + require.NoError(t, err, "Failed to get validator span map") + require.NotNil(t, sm) + require.DeepEqual(t, tt.spanMap, sm, "Get should return validator span map") err = db.DeleteEpochSpans(ctx, tt.epoch) - if err != nil { - t.Fatalf("Delete validator span map error: %v", err) - } + require.NoError(t, err, "Delete validator span map error") sm, _, err = db.EpochSpansMap(ctx, tt.epoch) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(sm, map[uint64]types.Span{}) { - t.Errorf("Expected validator span map to be deleted, received: %v", sm) - } + require.NoError(t, err) + require.DeepEqual(t, map[uint64]types.Span{}, sm, "Expected validator span map to be deleted") } } @@ -166,35 +133,24 @@ func TestValidatorSpanMap_DeletesOnCacheSavesToDB(t *testing.T) { for _, tt := range spanTests { err := db.SaveEpochSpansMap(ctx, tt.epoch, tt.spanMap) - if err != nil { - t.Fatalf("Save validator span map failed: %v", err) - } + require.NoError(t, err, "Save validator span map failed") } // Wait for value to pass through cache buffers. time.Sleep(time.Millisecond * 10) for _, tt := range spanTests { spanMap, _, err := db.EpochSpansMap(ctx, tt.epoch) - if err != nil { - t.Fatalf("Failed to get validator span map: %v", err) - } - if spanMap == nil || !reflect.DeepEqual(spanMap, tt.spanMap) { - t.Fatalf("Get should return validator span map: %v got: %v", tt.spanMap, spanMap) - } + require.NoError(t, err, "Failed to get validator span map") + require.NotNil(t, spanMap) + require.DeepEqual(t, tt.spanMap, spanMap, "Get should return validator span map") - if err = db.DeleteEpochSpans(ctx, tt.epoch); err != nil { - t.Fatalf("Delete validator span map error: %v", err) - } + require.NoError(t, db.DeleteEpochSpans(ctx, tt.epoch), "Delete validator span map error") // Wait for value to pass through cache buffers. db.EnableSpanCache(false) time.Sleep(time.Millisecond * 10) spanMap, _, err = db.EpochSpansMap(ctx, tt.epoch) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) db.EnableSpanCache(true) - if !reflect.DeepEqual(spanMap, tt.spanMap) { - t.Errorf("Expected validator span map to be deleted, received: %v", spanMap) - } + require.DeepEqual(t, tt.spanMap, spanMap, "Expected validator span map to be deleted") } } @@ -212,21 +168,16 @@ func TestValidatorSpanMap_SaveOnEvict(t *testing.T) { } for i := uint64(0); i < 6; i++ { err := db.SaveEpochSpansMap(ctx, i, tsm.spanMap) - if err != nil { - t.Fatalf("Save validator span map failed: %v", err) - } + require.NoError(t, err, "Save validator span map failed") } // Wait for value to pass through cache buffers. time.Sleep(time.Millisecond * 1000) for i := uint64(0); i < 6; i++ { sm, _, err := db.EpochSpansMap(ctx, i) - if err != nil { - t.Fatalf("Failed to get validator span map: %v", err) - } - if sm == nil || !reflect.DeepEqual(sm, tsm.spanMap) { - t.Fatalf("Get should return validator: %d span map: %v got: %v", i, tsm.spanMap, sm) - } + require.NoError(t, err, "Failed to get validator span map") + require.NotNil(t, sm) + require.DeepEqual(t, tsm.spanMap, sm, "Get should return validator") } } @@ -238,24 +189,16 @@ func TestValidatorSpanMap_SaveCachedSpansMaps(t *testing.T) { for _, tt := range spanTests { err := db.SaveEpochSpansMap(ctx, tt.epoch, tt.spanMap) - if err != nil { - t.Fatalf("Save validator span map failed: %v", err) - } + require.NoError(t, err, "Save validator span map failed") } // wait for value to pass through cache buffers time.Sleep(time.Millisecond * 10) - if err := db.SaveCachedSpansMaps(ctx); err != nil { - t.Errorf("Failed to save cached span maps to db: %v", err) - } + require.NoError(t, db.SaveCachedSpansMaps(ctx), "Failed to save cached span maps to db") db.spanCache.Purge() for _, tt := range spanTests { sm, _, err := db.EpochSpansMap(ctx, tt.epoch) - if err != nil { - t.Fatalf("Failed to get validator span map: %v", err) - } - if !reflect.DeepEqual(sm, tt.spanMap) { - t.Fatalf("Get should return validator span map: %v got: %v", tt.spanMap, sm) - } + require.NoError(t, err, "Failed to get validator span map") + require.DeepEqual(t, tt.spanMap, sm, "Get should return validator span map") } } @@ -268,36 +211,20 @@ func TestStore_ReadWriteEpochsSpanByValidatorsIndices(t *testing.T) { for _, tt := range spanTests { err := db.SaveEpochSpansMap(ctx, tt.epoch, tt.spanMap) - if err != nil { - t.Fatalf("Save validator span map failed: %v", err) - } + require.NoError(t, err, "Save validator span map failed") } res, err := db.EpochsSpanByValidatorsIndices(ctx, []uint64{1, 2, 3}, 3) - if err != nil { - t.Fatal(err) - } - if len(res) != len(spanTests) { - t.Errorf("Wanted map of %d elemets, received map of %d elements", len(spanTests), len(res)) - } + require.NoError(t, err) + assert.Equal(t, len(spanTests), len(res), "Unexpected number of elements") for _, tt := range spanTests { - if !reflect.DeepEqual(res[tt.epoch], tt.spanMap) { - t.Errorf("Wanted span map to be equal to: %v , received span map: %v ", spanTests[0].spanMap, res[1]) - } + assert.DeepEqual(t, tt.spanMap, res[tt.epoch], "Unexpected span map") } db1 := setupDB(t, cli.NewContext(&app, set, nil)) - if err := db1.SaveEpochsSpanByValidatorsIndices(ctx, res); err != nil { - t.Fatal(err) - } + require.NoError(t, db1.SaveEpochsSpanByValidatorsIndices(ctx, res)) res, err = db1.EpochsSpanByValidatorsIndices(ctx, []uint64{1, 2, 3}, 3) - if err != nil { - t.Fatal(err) - } - if len(res) != len(spanTests) { - t.Errorf("Wanted map of %d elemets, received map of %d elements", len(spanTests), len(res)) - } + require.NoError(t, err) + assert.Equal(t, len(spanTests), len(res), "Unexpected number of elements") for _, tt := range spanTests { - if !reflect.DeepEqual(res[tt.epoch], tt.spanMap) { - t.Errorf("Wanted span map to be equal to: %v , received span map: %v ", spanTests[0].spanMap, res[1]) - } + assert.DeepEqual(t, tt.spanMap, res[tt.epoch], "Unexpected span map") } } diff --git a/slasher/db/kv/validator_id_pubkey_test.go b/slasher/db/kv/validator_id_pubkey_test.go index e98052b12861..c7a2bac96cc0 100644 --- a/slasher/db/kv/validator_id_pubkey_test.go +++ b/slasher/db/kv/validator_id_pubkey_test.go @@ -1,11 +1,11 @@ package kv import ( - "bytes" "context" "flag" "testing" + "github.com/prysmaticlabs/prysm/shared/testutil/require" "github.com/urfave/cli/v2" ) @@ -42,13 +42,8 @@ func TestNilDBValidatorPublicKey(t *testing.T) { validatorID := uint64(1) pk, err := db.ValidatorPubKey(ctx, validatorID) - if err != nil { - t.Fatal("nil ValidatorPubKey should not return error") - } - if pk != nil { - t.Fatal("ValidatorPubKey should return nil") - } - + require.NoError(t, err, "Nil ValidatorPubKey should not return error") + require.DeepEqual(t, ([]uint8)(nil), pk) } func TestSavePubKey(t *testing.T) { @@ -59,20 +54,13 @@ func TestSavePubKey(t *testing.T) { for _, tt := range pkTests { err := db.SavePubKey(ctx, tt.validatorID, tt.pk) - if err != nil { - t.Fatalf("save validator public key failed: %v", err) - } + require.NoError(t, err, "Save validator public key failed") pk, err := db.ValidatorPubKey(ctx, tt.validatorID) - if err != nil { - t.Fatalf("failed to get validator public key: %v", err) - } - - if pk == nil || !bytes.Equal(pk, tt.pk) { - t.Fatalf("get should return validator public key: %v", tt.pk) - } + require.NoError(t, err, "Failed to get validator public key") + require.NotNil(t, pk) + require.DeepEqual(t, tt.pk, pk, "Should return validator public key") } - } func TestDeletePublicKey(t *testing.T) { @@ -82,34 +70,19 @@ func TestDeletePublicKey(t *testing.T) { ctx := context.Background() for _, tt := range pkTests { - - err := db.SavePubKey(ctx, tt.validatorID, tt.pk) - if err != nil { - t.Fatalf("save validator public key failed: %v", err) - } + require.NoError(t, db.SavePubKey(ctx, tt.validatorID, tt.pk), "Save validator public key failed") } for _, tt := range pkTests { pk, err := db.ValidatorPubKey(ctx, tt.validatorID) - if err != nil { - t.Fatalf("failed to get validator public key: %v", err) - } + require.NoError(t, err, "Failed to get validator public key") + require.NotNil(t, pk) + require.DeepEqual(t, tt.pk, pk, "Should return validator public key") - if pk == nil || !bytes.Equal(pk, tt.pk) { - t.Fatalf("get should return validator public key: %v", pk) - } err = db.DeletePubKey(ctx, tt.validatorID) - if err != nil { - t.Fatalf("delete validator public key: %v", err) - } + require.NoError(t, err, "Delete validator public key") pk, err = db.ValidatorPubKey(ctx, tt.validatorID) - if err != nil { - t.Fatal(err) - } - if pk != nil { - t.Errorf("Expected validator public key to be deleted, received: %v", pk) - } - + require.NoError(t, err) + require.DeepEqual(t, []byte(nil), pk, "Expected validator public key to be deleted") } - } diff --git a/slasher/db/testing/BUILD.bazel b/slasher/db/testing/BUILD.bazel index 82a913869898..85d5c92b1947 100644 --- a/slasher/db/testing/BUILD.bazel +++ b/slasher/db/testing/BUILD.bazel @@ -20,6 +20,7 @@ go_test( embed = [":go_default_library"], deps = [ "//shared/testutil:go_default_library", + "//shared/testutil/require:go_default_library", "//slasher/db:go_default_library", "//slasher/db/kv:go_default_library", ], diff --git a/slasher/db/testing/setup_db_test.go b/slasher/db/testing/setup_db_test.go index c0b263fe8c46..f020803091cf 100644 --- a/slasher/db/testing/setup_db_test.go +++ b/slasher/db/testing/setup_db_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/prysmaticlabs/prysm/shared/testutil" + "github.com/prysmaticlabs/prysm/shared/testutil/require" slasherDB "github.com/prysmaticlabs/prysm/slasher/db" "github.com/prysmaticlabs/prysm/slasher/db/kv" ) @@ -16,24 +17,14 @@ import ( func TestClearDB(t *testing.T) { // Setting up manually is required, since SetupDB() will also register a teardown procedure. randPath, err := rand.Int(rand.Reader, big.NewInt(1000000)) - if err != nil { - t.Fatalf("Could not generate random file path: %v", err) - } + require.NoError(t, err, "Could not generate random file path") p := path.Join(testutil.TempDir(), fmt.Sprintf("/%d", randPath)) - if err := os.RemoveAll(p); err != nil { - t.Fatalf("Failed to remove directory: %v", err) - } + require.NoError(t, os.RemoveAll(p), "Failed to remove directory") cfg := &kv.Config{} db, err := slasherDB.NewDB(p, cfg) - if err != nil { - t.Fatalf("Failed to instantiate DB: %v", err) - } + require.NoError(t, err, "Failed to instantiate DB") db.EnableSpanCache(false) - if err := db.ClearDB(); err != nil { - t.Fatal(err) - } - - if _, err := os.Stat(db.DatabasePath()); !os.IsNotExist(err) { - t.Fatalf("db wasnt cleared %v", err) - } + require.NoError(t, db.ClearDB()) + _, err = os.Stat(db.DatabasePath()) + require.Equal(t, true, os.IsNotExist(err), "Db wasnt cleared %v", err) } diff --git a/slasher/detection/BUILD.bazel b/slasher/detection/BUILD.bazel index 0ac9c6b36c7c..b38e0a6a8d08 100644 --- a/slasher/detection/BUILD.bazel +++ b/slasher/detection/BUILD.bazel @@ -45,6 +45,7 @@ go_test( deps = [ "//shared/bytesutil:go_default_library", "//shared/event:go_default_library", + "//shared/testutil/assert:go_default_library", "//shared/testutil/require:go_default_library", "//slasher/db/testing:go_default_library", "//slasher/db/types:go_default_library", diff --git a/slasher/detection/attestations/BUILD.bazel b/slasher/detection/attestations/BUILD.bazel index b7ff92c98299..5a8edd0b89dd 100644 --- a/slasher/detection/attestations/BUILD.bazel +++ b/slasher/detection/attestations/BUILD.bazel @@ -26,14 +26,20 @@ go_library( go_test( name = "go_default_test", - srcs = ["spanner_test.go"], + srcs = [ + "attestations_test.go", + "spanner_test.go", + ], embed = [":go_default_library"], deps = [ "//shared/featureconfig:go_default_library", "//shared/sliceutil:go_default_library", + "//shared/testutil/assert:go_default_library", + "//shared/testutil/require:go_default_library", "//slasher/db/testing:go_default_library", "//slasher/db/types:go_default_library", "//slasher/detection/attestations/types:go_default_library", "@com_github_prysmaticlabs_ethereumapis//eth/v1alpha1:go_default_library", + "@com_github_sirupsen_logrus//:go_default_library", ], ) diff --git a/slasher/detection/attestations/attestations_test.go b/slasher/detection/attestations/attestations_test.go new file mode 100644 index 000000000000..3b14cda82d0e --- /dev/null +++ b/slasher/detection/attestations/attestations_test.go @@ -0,0 +1,16 @@ +package attestations + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/sirupsen/logrus" +) + +func TestMain(m *testing.M) { + logrus.SetLevel(logrus.DebugLevel) + logrus.SetOutput(ioutil.Discard) + + os.Exit(m.Run()) +} diff --git a/slasher/detection/attestations/spanner_test.go b/slasher/detection/attestations/spanner_test.go index aa80a3f35e78..159811acf245 100644 --- a/slasher/detection/attestations/spanner_test.go +++ b/slasher/detection/attestations/spanner_test.go @@ -8,6 +8,8 @@ import ( ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" "github.com/prysmaticlabs/prysm/shared/featureconfig" "github.com/prysmaticlabs/prysm/shared/sliceutil" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" + "github.com/prysmaticlabs/prysm/shared/testutil/require" testDB "github.com/prysmaticlabs/prysm/slasher/db/testing" dbTypes "github.com/prysmaticlabs/prysm/slasher/db/types" "github.com/prysmaticlabs/prysm/slasher/detection/attestations/types" @@ -35,7 +37,7 @@ func TestSpanDetector_DetectSlashingsForAttestation_Double(t *testing.T) { name string att *ethpb.IndexedAttestation incomingAtt *ethpb.IndexedAttestation - slashCount uint64 + slashCount int } tests := []testStruct{ { @@ -243,14 +245,10 @@ func TestSpanDetector_DetectSlashingsForAttestation_Double(t *testing.T) { slasherDB: db, } - if err := sd.UpdateSpans(ctx, tt.att); err != nil { - t.Fatal(err) - } + require.NoError(t, sd.UpdateSpans(ctx, tt.att)) res, err := sd.DetectSlashingsForAttestation(ctx, tt.incomingAtt) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) var want []*types.DetectionResult if tt.slashCount > 0 { @@ -263,12 +261,8 @@ func TestSpanDetector_DetectSlashingsForAttestation_Double(t *testing.T) { }) } } - if !reflect.DeepEqual(res, want) { - t.Errorf("Wanted: %v, received %v", want, res) - } - if uint64(len(res)) != tt.slashCount { - t.Fatalf("Unexpected amount of slashings found, received %d, expected %d", len(res), tt.slashCount) - } + assert.DeepEqual(t, want, res) + require.Equal(t, tt.slashCount, len(res), "Unexpected amount of slashings found") }) } } @@ -475,20 +469,14 @@ func TestSpanDetector_DetectSlashingsForAttestation_Surround(t *testing.T) { validatorIndex := uint64(0) for k, v := range tt.spansByEpochForValidator { epochStore, err := types.NewEpochStore([]byte{}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) span := types.Span{ MinSpan: v[0], MaxSpan: v[1], } epochStore, err = epochStore.SetValidatorSpan(validatorIndex, span) - if err != nil { - t.Fatal(err) - } - if err := sd.slasherDB.SaveEpochSpans(ctx, k, epochStore, dbTypes.UseDB); err != nil { - t.Fatalf("Failed to save to slasherDB: %v", err) - } + require.NoError(t, err) + require.NoError(t, sd.slasherDB.SaveEpochSpans(ctx, k, epochStore, dbTypes.UseDB), "Failed to save to slasherDB") } att := ðpb.IndexedAttestation{ @@ -503,12 +491,8 @@ func TestSpanDetector_DetectSlashingsForAttestation_Surround(t *testing.T) { AttestingIndices: []uint64{0}, } res, err := sd.DetectSlashingsForAttestation(ctx, att) - if err != nil { - t.Fatal(err) - } - if !tt.shouldSlash && res != nil { - t.Fatalf("Did not want validator to be slashed but found slashable offense: %v", res) - } + require.NoError(t, err) + require.Equal(t, false, !tt.shouldSlash && res != nil, "Did not want validator to be slashed but found slashable offense: %v", res) if tt.shouldSlash { want := []*types.DetectionResult{ { @@ -516,9 +500,7 @@ func TestSpanDetector_DetectSlashingsForAttestation_Surround(t *testing.T) { SlashableEpoch: tt.slashableEpoch, }, } - if !reflect.DeepEqual(res, want) { - t.Errorf("Wanted: %v, received %v", want, res) - } + assert.DeepEqual(t, want, res) } }) } @@ -647,28 +629,20 @@ func TestSpanDetector_DetectSlashingsForAttestation_MultipleValidators(t *testin db := testDB.SetupSlasherDB(t, false) ctx := context.Background() defer func() { - if err := db.ClearDB(); err != nil { - t.Log(err) - } + assert.NoError(t, db.ClearDB()) }() defer func() { - if err := db.Close(); err != nil { - t.Log(err) - } + assert.NoError(t, db.Close()) }() spanDetector := &SpanDetector{ slasherDB: db, } for _, att := range tt.atts { - if err := spanDetector.UpdateSpans(ctx, att); err != nil { - t.Fatalf("Failed to save to slasherDB: %v", err) - } + require.NoError(t, spanDetector.UpdateSpans(ctx, att), "Failed to save to slasherDB") } res, err := spanDetector.DetectSlashingsForAttestation(ctx, tt.incomingAtt) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) var want []*types.DetectionResult for i := 0; i < len(tt.incomingAtt.AttestingIndices); i++ { if tt.shouldSlash[i] { @@ -814,34 +788,22 @@ func TestNewSpanDetector_UpdateSpans(t *testing.T) { db := testDB.SetupSlasherDB(t, false) ctx := context.Background() defer func() { - if err := db.ClearDB(); err != nil { - t.Log(err) - } + assert.NoError(t, db.ClearDB()) }() defer func() { - if err := db.Close(); err != nil { - t.Log(err) - } + assert.NoError(t, db.Close()) }() sd := &SpanDetector{ slasherDB: db, } - if err := sd.UpdateSpans(ctx, tt.att); err != nil { - t.Fatal(err) - } + require.NoError(t, sd.UpdateSpans(ctx, tt.att)) for epoch := range tt.want { sm, err := sd.slasherDB.EpochSpans(ctx, uint64(epoch), dbTypes.UseDB) - if err != nil { - t.Fatalf("Failed to read from slasherDB: %v", err) - } + require.NoError(t, err, "Failed to read from slasherDB") resMap, err := sm.ToMap() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(resMap, tt.want[epoch]) { - t.Errorf("Wanted and received:\n%v \n%v", tt.want[epoch], resMap) - } + require.NoError(t, err) + assert.DeepEqual(t, tt.want[epoch], resMap) } }) } @@ -868,25 +830,15 @@ func TestSpanDetector_UpdateMinSpansCheckCacheSize(t *testing.T) { db := testDB.SetupSlasherDB(t, false) ctx := context.Background() defer func() { - if err := db.ClearDB(); err != nil { - t.Log(err) - } + assert.NoError(t, db.ClearDB()) }() defer func() { - if err := db.Close(); err != nil { - t.Log(err) - } + assert.NoError(t, db.Close()) }() sd := &SpanDetector{ slasherDB: db, } - if err := sd.updateMinSpan(ctx, att); err != nil { - t.Fatal(err) - } - - if len := db.CacheLength(ctx); len != epochLookback { - t.Fatalf("Expected cache length to be equal to epochLookback: %d got: %d", epochLookback, len) - } - + require.NoError(t, sd.updateMinSpan(ctx, att)) + require.Equal(t, epochLookback, db.CacheLength(ctx), "Unexpected cache length") } diff --git a/slasher/detection/attestations/types/BUILD.bazel b/slasher/detection/attestations/types/BUILD.bazel index ca7e0fb6feba..abb21575b835 100644 --- a/slasher/detection/attestations/types/BUILD.bazel +++ b/slasher/detection/attestations/types/BUILD.bazel @@ -20,6 +20,8 @@ go_test( srcs = ["epoch_store_test.go"], embed = [":go_default_library"], deps = [ + "//shared/testutil/assert:go_default_library", + "//shared/testutil/require:go_default_library", "//slasher/db/testing:go_default_library", "//slasher/db/types:go_default_library", ], diff --git a/slasher/detection/attestations/types/epoch_store_test.go b/slasher/detection/attestations/types/epoch_store_test.go index b27c0263d8a5..44eed42e8ed2 100644 --- a/slasher/detection/attestations/types/epoch_store_test.go +++ b/slasher/detection/attestations/types/epoch_store_test.go @@ -4,9 +4,10 @@ import ( "context" "encoding/hex" "fmt" - "reflect" "testing" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" + "github.com/prysmaticlabs/prysm/shared/testutil/require" testDB "github.com/prysmaticlabs/prysm/slasher/db/testing" dbTypes "github.com/prysmaticlabs/prysm/slasher/db/types" "github.com/prysmaticlabs/prysm/slasher/detection/attestations/types" @@ -54,24 +55,18 @@ func TestEpochStore_GetValidatorSpan_Format(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { decodedHex, err := hex.DecodeString(tt.hexToDecode) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) es, err := types.NewEpochStore(decodedHex) - if err != tt.expectedErr { - t.Fatalf("expected error %v, received %v", tt.expectedErr, err) - } if tt.expectedErr != nil { + require.ErrorContains(t, tt.expectedErr.Error(), err) return + } else { + require.NoError(t, err) } span0, err := es.GetValidatorSpan(0) - if !reflect.DeepEqual(span0, tt.expectedSpan[0]) { - t.Errorf("Expected span to be: %v, received: %v", tt.expectedSpan[0], span0) - } + assert.DeepEqual(t, tt.expectedSpan[0], span0, "Unexpected span") span1, err := es.GetValidatorSpan(1) - if !reflect.DeepEqual(span1, tt.expectedSpan[1]) { - t.Errorf("Expected span to be: %v, received: %v", tt.expectedSpan[1], span1) - } + assert.DeepEqual(t, tt.expectedSpan[1], span1, "Unexpected span") }) } } @@ -121,20 +116,12 @@ func TestEpochStore_GetValidatorSpan_Matches(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { es, err := types.EpochStoreFromMap(tt.spanMap) - if err != nil { - t.Fatal(err) - } - if es.HighestObservedIdx() != tt.highestIdx { - t.Fatalf("Expected highest index %d, received %d", tt.highestIdx, es.HighestObservedIdx()) - } + require.NoError(t, err) + require.Equal(t, tt.highestIdx, es.HighestObservedIdx(), "Unexpected highest index") for k, v := range tt.spanMap { span, err := es.GetValidatorSpan(k) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(span, v) { - t.Fatalf("Expected span %v, received %v", v, span) - } + require.NoError(t, err) + require.DeepEqual(t, v, span, "Unexpected span") } }) } @@ -201,36 +188,24 @@ func TestEpochStore_SetValidatorSpan(t *testing.T) { }, } es, err := types.NewEpochStore([]byte{}) - if err != nil { - t.Fatal(err) - } - if es.HighestObservedIdx() != 0 { - t.Fatalf("Expected highest index to be 0, received %d", es.HighestObservedIdx()) - } + require.NoError(t, err) + require.Equal(t, uint64(0), es.HighestObservedIdx(), "Expected highest index to be 0") lastIdx := uint64(0) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { for k, v := range tt.spanMapToAdd { es, err = es.SetValidatorSpan(k, v) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if k > lastIdx { lastIdx = k } } for k, v := range tt.resultMap { span, err := es.GetValidatorSpan(k) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(v, span) { - t.Errorf("Expected %v, received %v", v, span) - } - } - if es.HighestObservedIdx() != lastIdx { - t.Fatalf("Expected highest index to be %d, received %d", lastIdx, es.HighestObservedIdx()) + require.NoError(t, err) + assert.DeepEqual(t, v, span, "Unexpected span") } + require.Equal(t, lastIdx, es.HighestObservedIdx(), "Unexpected highest index") }) } } @@ -246,9 +221,7 @@ func BenchmarkEpochStore_Save(b *testing.B) { b.ReportAllocs() b.N = 5 for i := 0; i < b.N; i++ { - if err := db.SaveEpochSpansMap(context.Background(), 0, spansMap); err != nil { - b.Fatal(err) - } + require.NoError(b, db.SaveEpochSpansMap(context.Background(), 0, spansMap)) } }) @@ -257,18 +230,14 @@ func BenchmarkEpochStore_Save(b *testing.B) { b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - if err := db.SaveEpochSpans(context.Background(), 1, store, dbTypes.UseDB); err != nil { - b.Fatal(err) - } + require.NoError(b, db.SaveEpochSpans(context.Background(), 1, store, dbTypes.UseDB)) } }) } func generateEpochStore(t testing.TB, n uint64) (*types.EpochStore, map[uint64]types.Span) { epochStore, err := types.NewEpochStore([]byte{}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) spanMap := make(map[uint64]types.Span) for i := uint64(0); i < n; i++ { span := types.Span{ @@ -279,9 +248,7 @@ func generateEpochStore(t testing.TB, n uint64) (*types.EpochStore, map[uint64]t } spanMap[i] = span epochStore, err = epochStore.SetValidatorSpan(i, span) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) } return epochStore, spanMap } diff --git a/slasher/detection/detect_test.go b/slasher/detection/detect_test.go index 4dade06a0016..44322a3dcd28 100644 --- a/slasher/detection/detect_test.go +++ b/slasher/detection/detect_test.go @@ -7,6 +7,8 @@ import ( ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" "github.com/prysmaticlabs/prysm/shared/bytesutil" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" + "github.com/prysmaticlabs/prysm/shared/testutil/require" testDB "github.com/prysmaticlabs/prysm/slasher/db/testing" status "github.com/prysmaticlabs/prysm/slasher/db/types" "github.com/prysmaticlabs/prysm/slasher/detection/attestations" @@ -154,26 +156,16 @@ func TestDetect_detectAttesterSlashings_Surround(t *testing.T) { slasherDB: db, minMaxSpanDetector: attestations.NewSpanDetector(db), } - if err := db.SaveIndexedAttestations(ctx, tt.savedAtts); err != nil { - t.Fatal(err) - } + require.NoError(t, db.SaveIndexedAttestations(ctx, tt.savedAtts)) for _, att := range tt.savedAtts { - if err := ds.minMaxSpanDetector.UpdateSpans(ctx, att); err != nil { - t.Fatal(err) - } + require.NoError(t, ds.minMaxSpanDetector.UpdateSpans(ctx, att)) } slashings, err := ds.DetectAttesterSlashings(ctx, tt.incomingAtt) - if err != nil { - t.Fatal(err) - } - if len(slashings) != tt.slashingsFound { - t.Fatalf("Unexpected amount of slashings found, received %d, expected %d", len(slashings), tt.slashingsFound) - } + require.NoError(t, err) + require.Equal(t, tt.slashingsFound, len(slashings), "Unexpected amount of slashings found") attsl, err := db.AttesterSlashings(ctx, status.Active) - if len(attsl) != tt.slashingsFound { - t.Fatalf("Didnt save slashing to db") - } + require.Equal(t, tt.slashingsFound, len(attsl), "Didnt save slashing to db") for _, ss := range slashings { slashingAtt1 := ss.Attestation_1 slashingAtt2 := ss.Attestation_2 @@ -307,26 +299,16 @@ func TestDetect_detectAttesterSlashings_Double(t *testing.T) { slasherDB: db, minMaxSpanDetector: attestations.NewSpanDetector(db), } - if err := db.SaveIndexedAttestations(ctx, tt.savedAtts); err != nil { - t.Fatal(err) - } + require.NoError(t, db.SaveIndexedAttestations(ctx, tt.savedAtts)) for _, att := range tt.savedAtts { - if err := ds.minMaxSpanDetector.UpdateSpans(ctx, att); err != nil { - t.Fatal(err) - } + require.NoError(t, ds.minMaxSpanDetector.UpdateSpans(ctx, att)) } slashings, err := ds.DetectAttesterSlashings(ctx, tt.incomingAtt) - if err != nil { - t.Fatal(err) - } - if len(slashings) != tt.slashingsFound { - t.Fatalf("Unexpected amount of slashings found, received %d, expected %d", len(slashings), tt.slashingsFound) - } + require.NoError(t, err) + require.Equal(t, tt.slashingsFound, len(slashings), "Unexpected amount of slashings found") savedSlashings, err := db.AttesterSlashings(ctx, status.Active) - if len(savedSlashings) != tt.slashingsFound { - t.Fatalf("Did not save slashing to db") - } + require.Equal(t, tt.slashingsFound, len(savedSlashings), "Did not save slashing to db") for _, ss := range slashings { slashingAtt1 := ss.Attestation_1 @@ -339,7 +321,6 @@ func TestDetect_detectAttesterSlashings_Double(t *testing.T) { ) } } - }) } } @@ -352,17 +333,11 @@ func TestDetect_detectProposerSlashing(t *testing.T) { slashing *ethpb.ProposerSlashing } sigBlk1slot0, err := testDetect.SignedBlockHeader(testDetect.StartSlot(0), 0) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) sigBlk2slot0, err := testDetect.SignedBlockHeader(testDetect.StartSlot(0), 0) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) sigBlk1epoch1, err := testDetect.SignedBlockHeader(testDetect.StartSlot(1), 0) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) tests := []testStruct{ { name: "same block sig dont slash", @@ -392,20 +367,14 @@ func TestDetect_detectProposerSlashing(t *testing.T) { slasherDB: db, proposalsDetector: proposals.NewProposeDetector(db), } - if err := db.SaveBlockHeader(ctx, tt.blk); err != nil { - t.Fatal(err) - } + require.NoError(t, db.SaveBlockHeader(ctx, tt.blk)) slashing, err := ds.proposalsDetector.DetectDoublePropose(ctx, tt.incomingBlk) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(slashing, tt.slashing) { - t.Errorf("Wanted: %v, received %v", tt.slashing, slashing) - } + require.NoError(t, err) + assert.DeepEqual(t, tt.slashing, slashing) savedSlashings, err := db.ProposalSlashingsByStatus(ctx, status.Active) - if tt.slashing != nil && len(savedSlashings) != 1 { - t.Fatalf("Did not save slashing to db") + if tt.slashing != nil { + require.Equal(t, 1, len(savedSlashings), "Did not save slashing to db") } if slashing != nil && !isDoublePropose(slashing.Header_1, slashing.Header_2) { @@ -415,7 +384,6 @@ func TestDetect_detectProposerSlashing(t *testing.T) { slashing.Header_2, ) } - }) } } @@ -427,28 +395,18 @@ func TestDetect_detectProposerSlashingNoUpdate(t *testing.T) { slashable bool } sigBlk1slot0, err := testDetect.SignedBlockHeader(testDetect.StartSlot(0), 0) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) blk1slot0, err := testDetect.BlockHeader(testDetect.StartSlot(0), 0) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) blk2slot0, err := testDetect.BlockHeader(testDetect.StartSlot(0), 0) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) diffRoot := [32]byte{1, 1, 1} blk2slot0.ParentRoot = diffRoot[:] blk3slot0, err := testDetect.BlockHeader(testDetect.StartSlot(0), 0) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) blk3slot0.StateRoot = diffRoot[:] blk4slot0, err := testDetect.BlockHeader(testDetect.StartSlot(0), 0) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) blk4slot0.BodyRoot = diffRoot[:] tests := []testStruct{ { @@ -485,17 +443,11 @@ func TestDetect_detectProposerSlashingNoUpdate(t *testing.T) { slasherDB: db, proposalsDetector: proposals.NewProposeDetector(db), } - if err := db.SaveBlockHeader(ctx, tt.blk); err != nil { - t.Fatal(err) - } + require.NoError(t, db.SaveBlockHeader(ctx, tt.blk)) slashble, err := ds.proposalsDetector.DetectDoubleProposeNoUpdate(ctx, tt.noUpdtaeBlk) - if err != nil { - t.Fatal(err) - } - if slashble != tt.slashable { - t.Errorf("Wanted slashbale: %v, received slashable: %v", tt.slashable, slashble) - } + require.NoError(t, err) + assert.Equal(t, tt.slashable, slashble) }) } } @@ -574,15 +526,11 @@ func TestServer_MapResultsToAtts(t *testing.T) { }, } for _, atts := range expectedResultsToAtts { - if err := ds.slasherDB.SaveIndexedAttestations(ctx, atts); err != nil { - t.Fatal(err) - } + require.NoError(t, ds.slasherDB.SaveIndexedAttestations(ctx, atts)) } resultsToAtts, err := ds.mapResultsToAtts(ctx, results) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if !reflect.DeepEqual(expectedResultsToAtts, resultsToAtts) { t.Error("Expected map:") for key, value := range resultsToAtts { diff --git a/slasher/detection/proposals/BUILD.bazel b/slasher/detection/proposals/BUILD.bazel index 7be8fb730b45..7204f38d6062 100644 --- a/slasher/detection/proposals/BUILD.bazel +++ b/slasher/detection/proposals/BUILD.bazel @@ -16,12 +16,18 @@ go_library( go_test( name = "go_default_test", - srcs = ["detector_test.go"], + srcs = [ + "detector_test.go", + "proposals_test.go", + ], embed = [":go_default_library"], deps = [ + "//shared/testutil/assert:go_default_library", + "//shared/testutil/require:go_default_library", "//slasher/db/testing:go_default_library", "//slasher/detection/proposals/iface:go_default_library", "//slasher/detection/testing:go_default_library", "@com_github_prysmaticlabs_ethereumapis//eth/v1alpha1:go_default_library", + "@com_github_sirupsen_logrus//:go_default_library", ], ) diff --git a/slasher/detection/proposals/detector_test.go b/slasher/detection/proposals/detector_test.go index 4ff309329763..ea2725545961 100644 --- a/slasher/detection/proposals/detector_test.go +++ b/slasher/detection/proposals/detector_test.go @@ -2,10 +2,11 @@ package proposals import ( "context" - "reflect" "testing" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" + "github.com/prysmaticlabs/prysm/shared/testutil/require" testDB "github.com/prysmaticlabs/prysm/slasher/db/testing" "github.com/prysmaticlabs/prysm/slasher/detection/proposals/iface" testDetect "github.com/prysmaticlabs/prysm/slasher/detection/testing" @@ -21,21 +22,13 @@ func TestProposalsDetector_DetectSlashingsForBlockHeaders(t *testing.T) { slashing *ethpb.ProposerSlashing } blk1slot0, err := testDetect.SignedBlockHeader(testDetect.StartSlot(0), 0) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) blk2slot0, err := testDetect.SignedBlockHeader(testDetect.StartSlot(0), 0) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) blk1slot1, err := testDetect.SignedBlockHeader(testDetect.StartSlot(0)+1, 0) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) blk1epoch1, err := testDetect.SignedBlockHeader(testDetect.StartSlot(1), 0) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) tests := []testStruct{ { name: "same block sig dont slash", @@ -72,19 +65,11 @@ func TestProposalsDetector_DetectSlashingsForBlockHeaders(t *testing.T) { slasherDB: db, } - if err := sd.slasherDB.SaveBlockHeader(ctx, tt.blk); err != nil { - t.Fatal(err) - } + require.NoError(t, sd.slasherDB.SaveBlockHeader(ctx, tt.blk)) res, err := sd.DetectDoublePropose(ctx, tt.incomingBlk) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(res, tt.slashing) { - t.Errorf("Wanted: %v, received %v", tt.slashing, res) - } - + require.NoError(t, err) + assert.DeepEqual(t, tt.slashing, res) }) } } diff --git a/slasher/detection/proposals/proposals_test.go b/slasher/detection/proposals/proposals_test.go new file mode 100644 index 000000000000..5a72fb898345 --- /dev/null +++ b/slasher/detection/proposals/proposals_test.go @@ -0,0 +1,16 @@ +package proposals + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/sirupsen/logrus" +) + +func TestMain(m *testing.M) { + logrus.SetLevel(logrus.DebugLevel) + logrus.SetOutput(ioutil.Discard) + + os.Exit(m.Run()) +} diff --git a/slasher/node/BUILD.bazel b/slasher/node/BUILD.bazel index 9cb1150cc4a8..8dd3fceb2f65 100644 --- a/slasher/node/BUILD.bazel +++ b/slasher/node/BUILD.bazel @@ -35,6 +35,7 @@ go_test( deps = [ "//shared/testutil:go_default_library", "//shared/testutil/require:go_default_library", + "@com_github_sirupsen_logrus//:go_default_library", "@com_github_sirupsen_logrus//hooks/test:go_default_library", "@com_github_urfave_cli_v2//:go_default_library", ], diff --git a/slasher/node/node_test.go b/slasher/node/node_test.go index 1c09c3c051a8..a0c09bfd9166 100644 --- a/slasher/node/node_test.go +++ b/slasher/node/node_test.go @@ -3,23 +3,30 @@ package node import ( "flag" "fmt" + "io/ioutil" "os" "testing" "github.com/prysmaticlabs/prysm/shared/testutil" "github.com/prysmaticlabs/prysm/shared/testutil/require" + "github.com/sirupsen/logrus" logTest "github.com/sirupsen/logrus/hooks/test" "github.com/urfave/cli/v2" ) +func TestMain(m *testing.M) { + logrus.SetLevel(logrus.DebugLevel) + logrus.SetOutput(ioutil.Discard) + + os.Exit(m.Run()) +} + // Test that slasher node can close. func TestNodeClose_OK(t *testing.T) { hook := logTest.NewGlobal() tmp := fmt.Sprintf("%s/datadirtest2", testutil.TempDir()) - if err := os.RemoveAll(tmp); err != nil { - t.Fatal(err) - } + require.NoError(t, os.RemoveAll(tmp)) app := cli.App{} set := flag.NewFlagSet("test", 0) @@ -29,15 +36,10 @@ func TestNodeClose_OK(t *testing.T) { context := cli.NewContext(&app, set, nil) node, err := NewSlasherNode(context) - if err != nil { - t.Fatalf("Failed to create SlasherNode: %v", err) - } + require.NoError(t, err, "Failed to create slasher node") node.Close() require.LogsContain(t, hook, "Stopping hash slinging slasher") - - if err := os.RemoveAll(tmp); err != nil { - t.Fatal(err) - } + require.NoError(t, os.RemoveAll(tmp)) } diff --git a/slasher/rpc/BUILD.bazel b/slasher/rpc/BUILD.bazel index 2514a67569ac..906538111e85 100644 --- a/slasher/rpc/BUILD.bazel +++ b/slasher/rpc/BUILD.bazel @@ -40,6 +40,7 @@ go_library( go_test( name = "go_default_test", srcs = [ + "rpc_test.go", "server_test.go", "service_test.go", ], @@ -54,6 +55,7 @@ go_test( "//shared/p2putils:go_default_library", "//shared/params:go_default_library", "//shared/testutil:go_default_library", + "//shared/testutil/assert:go_default_library", "//shared/testutil/require:go_default_library", "//slasher/beaconclient:go_default_library", "//slasher/db/testing:go_default_library", diff --git a/slasher/rpc/rpc_test.go b/slasher/rpc/rpc_test.go new file mode 100644 index 000000000000..6d99b8fa6432 --- /dev/null +++ b/slasher/rpc/rpc_test.go @@ -0,0 +1,16 @@ +package rpc + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/sirupsen/logrus" +) + +func TestMain(m *testing.M) { + logrus.SetLevel(logrus.DebugLevel) + logrus.SetOutput(ioutil.Discard) + + os.Exit(m.Run()) +} diff --git a/slasher/rpc/server_test.go b/slasher/rpc/server_test.go index bcfb2012f078..a77c14d9eb9b 100644 --- a/slasher/rpc/server_test.go +++ b/slasher/rpc/server_test.go @@ -17,6 +17,8 @@ import ( "github.com/prysmaticlabs/prysm/shared/p2putils" "github.com/prysmaticlabs/prysm/shared/params" "github.com/prysmaticlabs/prysm/shared/testutil" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" + "github.com/prysmaticlabs/prysm/shared/testutil/require" "github.com/prysmaticlabs/prysm/slasher/beaconclient" testDB "github.com/prysmaticlabs/prysm/slasher/db/testing" "github.com/prysmaticlabs/prysm/slasher/detection" @@ -31,9 +33,7 @@ func TestServer_IsSlashableAttestation(t *testing.T) { ctx := context.Background() _, keys, err := testutil.DeterministicDepositsAndKeys(4) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) wantedValidators1 := ðpb.Validators{ ValidatorList: []*ethpb.Validators_ValidatorContainer{ { @@ -58,9 +58,7 @@ func TestServer_IsSlashableAttestation(t *testing.T) { SlasherDB: db, } fork, err := p2putils.Fork(savedAttestation.Data.Target.Epoch) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) bcCfg := &beaconclient.Config{BeaconClient: bClient, NodeClient: nClient, SlasherDB: db} bs, err := beaconclient.NewBeaconClientService(ctx, bcCfg) @@ -72,9 +70,7 @@ func TestServer_IsSlashableAttestation(t *testing.T) { gomock.Any(), ).Return(wantedValidators1, nil).AnyTimes() domain, err := helpers.Domain(fork, savedAttestation.Data.Target.Epoch, params.BeaconConfig().DomainBeaconAttester, wantedGenesis.GenesisValidatorsRoot) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) wg := sync.WaitGroup{} wg.Add(100) var wentThrough bool @@ -84,17 +80,13 @@ func TestServer_IsSlashableAttestation(t *testing.T) { iatt := state.CopyIndexedAttestation(savedAttestation) iatt.Data.Slot += j root, err := helpers.ComputeSigningRoot(iatt.Data, domain) - if err != nil { - t.Error(err) - } + require.NoError(t, err) var validatorSig bls.Signature validatorSig = keys[iatt.AttestingIndices[0]].Sign(root[:]) marshalledSig := validatorSig.Marshal() iatt.Signature = marshalledSig slashings, err := server.IsSlashableAttestation(ctx, iatt) - if err != nil { - t.Fatalf("got error while trying to detect slashing: %v", err) - } + require.NoError(t, err, "Got error while trying to detect slashing") if len(slashings.AttesterSlashing) == 0 && !wentThrough { wentThrough = true @@ -116,9 +108,7 @@ func TestServer_IsSlashableAttestationNoUpdate(t *testing.T) { ctx := context.Background() _, keys, err := testutil.DeterministicDepositsAndKeys(4) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) wantedValidators1 := ðpb.Validators{ ValidatorList: []*ethpb.Validators_ValidatorContainer{ { @@ -153,17 +143,11 @@ func TestServer_IsSlashableAttestationNoUpdate(t *testing.T) { SlasherDB: db, } fork, err := p2putils.Fork(savedAttestation.Data.Target.Epoch) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) domain, err := helpers.Domain(fork, savedAttestation.Data.Target.Epoch, params.BeaconConfig().DomainBeaconAttester, wantedGenesis.GenesisValidatorsRoot) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := helpers.ComputeSigningRoot(savedAttestation.Data, domain) - if err != nil { - t.Error(err) - } + require.NoError(t, err) sig := []bls.Signature{} for _, idx := range savedAttestation.AttestingIndices { validatorSig := keys[idx].Sign(root[:]) @@ -179,19 +163,11 @@ func TestServer_IsSlashableAttestationNoUpdate(t *testing.T) { ds := detection.NewDetectionService(ctx, cfg) server := Server{ctx: ctx, detector: ds, slasherDB: db, beaconClient: bs} slashings, err := server.IsSlashableAttestation(ctx, savedAttestation) - if err != nil { - t.Fatalf("got error while trying to detect slashing: %v", err) - } - if len(slashings.AttesterSlashing) != 0 { - t.Fatalf("Found slashings while no slashing should have been found on first attestation: %v slashing found: %v", savedAttestation, slashings) - } + require.NoError(t, err, "Got error while trying to detect slashing") + require.Equal(t, 0, len(slashings.AttesterSlashing), "Found slashings while no slashing should have been found on first attestation") sl, err := server.IsSlashableAttestationNoUpdate(ctx, incomingAtt) - if err != nil { - t.Fatalf("got error while trying to detect slashing: %v", err) - } - if sl.Slashable != true { - t.Fatalf("attestation should be found to be slashable. got: %v", sl.Slashable) - } + require.NoError(t, err, "Got error while trying to detect slashing") + require.Equal(t, true, sl.Slashable, "Attestation should be found to be slashable") } func TestServer_IsSlashableBlock(t *testing.T) { @@ -203,9 +179,7 @@ func TestServer_IsSlashableBlock(t *testing.T) { ctx := context.Background() _, keys, err := testutil.DeterministicDepositsAndKeys(4) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) wantedValidators := ðpb.Validators{ ValidatorList: []*ethpb.Validators_ValidatorContainer{ { @@ -235,13 +209,9 @@ func TestServer_IsSlashableBlock(t *testing.T) { } savedBlockEpoch := helpers.SlotToEpoch(savedBlock.Header.Slot) fork, err := p2putils.Fork(savedBlockEpoch) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) domain, err := helpers.Domain(fork, savedBlockEpoch, params.BeaconConfig().DomainBeaconProposer, wantedGenesis.GenesisValidatorsRoot) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) bcCfg := &beaconclient.Config{BeaconClient: bClient, NodeClient: nClient, SlasherDB: db} bs, err := beaconclient.NewBeaconClientService(ctx, bcCfg) @@ -257,18 +227,12 @@ func TestServer_IsSlashableBlock(t *testing.T) { sbbh := state.CopySignedBeaconBlockHeader(savedBlock) sbbh.Header.BodyRoot = bytesutil.PadTo([]byte(fmt.Sprintf("%d", j)), 32) bhr, err := stateutil.BlockHeaderRoot(sbbh.Header) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) root, err := helpers.ComputeSigningRoot(bhr, domain) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) sbbh.Signature = keys[sbbh.Header.ProposerIndex].Sign(root[:]).Marshal() slashings, err := server.IsSlashableBlock(ctx, sbbh) - if err != nil { - t.Fatalf("got error while trying to detect slashing: %v", err) - } + require.NoError(t, err, "Got error while trying to detect slashing") if len(slashings.ProposerSlashing) == 0 && !wentThrough { wentThrough = true } else if len(slashings.ProposerSlashing) == 0 && wentThrough { @@ -277,7 +241,6 @@ func TestServer_IsSlashableBlock(t *testing.T) { }(i) } wg.Wait() - } func TestServer_IsSlashableBlockNoUpdate(t *testing.T) { @@ -289,9 +252,7 @@ func TestServer_IsSlashableBlockNoUpdate(t *testing.T) { ctx := context.Background() _, keys, err := testutil.DeterministicDepositsAndKeys(4) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) wantedValidators := ðpb.Validators{ ValidatorList: []*ethpb.Validators_ValidatorContainer{ { @@ -325,21 +286,13 @@ func TestServer_IsSlashableBlockNoUpdate(t *testing.T) { } savedBlockEpoch := helpers.SlotToEpoch(savedBlock.Header.Slot) fork, err := p2putils.Fork(savedBlockEpoch) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) domain, err := helpers.Domain(fork, savedBlockEpoch, params.BeaconConfig().DomainBeaconProposer, wantedGenesis.GenesisValidatorsRoot) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) bhr, err := stateutil.BlockHeaderRoot(savedBlock.Header) - if err != nil { - t.Error(err) - } + require.NoError(t, err) root, err := helpers.ComputeSigningRoot(bhr, domain) - if err != nil { - t.Error(err) - } + require.NoError(t, err) blockSig := keys[savedBlock.Header.ProposerIndex].Sign(root[:]) marshalledSig := blockSig.Marshal() savedBlock.Signature = marshalledSig @@ -348,17 +301,9 @@ func TestServer_IsSlashableBlockNoUpdate(t *testing.T) { ds := detection.NewDetectionService(ctx, cfg) server := Server{ctx: ctx, detector: ds, slasherDB: db, beaconClient: bs} slashings, err := server.IsSlashableBlock(ctx, savedBlock) - if err != nil { - t.Fatalf("got error while trying to detect slashing: %v", err) - } - if len(slashings.ProposerSlashing) != 0 { - t.Fatalf("Found slashings while no slashing should have been found on first block: %v slashing found: %v", savedBlock, slashings) - } + require.NoError(t, err, "Got error while trying to detect slashing") + require.Equal(t, 0, len(slashings.ProposerSlashing), "Found slashings while no slashing should have been found on first block") sl, err := server.IsSlashableBlockNoUpdate(ctx, incomingBlock) - if err != nil { - t.Fatalf("got error while trying to detect slashing: %v", err) - } - if sl.Slashable != true { - t.Fatalf("block should be found to be slashable. got: %v", sl.Slashable) - } + require.NoError(t, err, "Got error while trying to detect slashing") + require.Equal(t, true, sl.Slashable, "Block should be found to be slashable") } diff --git a/slasher/rpc/service_test.go b/slasher/rpc/service_test.go index 67892ea317b3..4965b78418ad 100644 --- a/slasher/rpc/service_test.go +++ b/slasher/rpc/service_test.go @@ -4,19 +4,13 @@ import ( "context" "errors" "fmt" - "io/ioutil" "testing" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" "github.com/prysmaticlabs/prysm/shared/testutil/require" - "github.com/sirupsen/logrus" logTest "github.com/sirupsen/logrus/hooks/test" ) -func init() { - logrus.SetLevel(logrus.DebugLevel) - logrus.SetOutput(ioutil.Discard) -} - func TestLifecycle_OK(t *testing.T) { hook := logTest.NewGlobal() rpcService := NewService(context.Background(), &Config{ @@ -28,19 +22,14 @@ func TestLifecycle_OK(t *testing.T) { rpcService.Start() require.LogsContain(t, hook, "listening on port") - - if err := rpcService.Stop(); err != nil { - t.Error(err) - } + require.NoError(t, rpcService.Stop()) } func TestStatus_CredentialError(t *testing.T) { credentialErr := errors.New("credentialError") s := &Service{credentialError: credentialErr} - if err := s.Status(); err != s.credentialError { - t.Errorf("Wanted: %v, got: %v", s.credentialError, s.Status()) - } + assert.ErrorContains(t, s.credentialError.Error(), s.Status()) } func TestRPC_InsecureEndpoint(t *testing.T) { @@ -52,8 +41,5 @@ func TestRPC_InsecureEndpoint(t *testing.T) { rpcService.Start() require.LogsContain(t, hook, fmt.Sprint("listening on port")) - - if err := rpcService.Stop(); err != nil { - t.Error(err) - } + require.NoError(t, rpcService.Stop()) } diff --git a/slasher/usage_test.go b/slasher/usage_test.go index 31174285a342..eccba4aa9c0c 100644 --- a/slasher/usage_test.go +++ b/slasher/usage_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/prysmaticlabs/prysm/shared/featureconfig" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" "github.com/urfave/cli/v2" ) @@ -20,16 +21,11 @@ func TestAllFlagsExistInHelp(t *testing.T) { appFlags = featureconfig.ActiveFlags(appFlags) for _, flag := range appFlags { - if !doesFlagExist(flag, helpFlags) { - t.Errorf("Flag %s does not exist in help/usage flags.", flag.Names()[0]) - } + assert.Equal(t, true, doesFlagExist(flag, helpFlags), "Flag %s does not exist in help/usage flags.", flag.Names()[0]) } for _, flag := range helpFlags { - if !doesFlagExist(flag, appFlags) { - t.Errorf("Flag %s does not exist in main.go, "+ - "but exists in help flags", flag.Names()[0]) - } + assert.Equal(t, true, doesFlagExist(flag, appFlags), "Flag %s does not exist in main.go, but exists in help flags", flag.Names()[0]) } }