diff --git a/rolling-shutter/cmd/cryptocmd/jsontests.go b/rolling-shutter/cmd/cryptocmd/jsontests.go index 57abc0288..8dca38740 100644 --- a/rolling-shutter/cmd/cryptocmd/jsontests.go +++ b/rolling-shutter/cmd/cryptocmd/jsontests.go @@ -315,8 +315,10 @@ var testSpecs = []struct { } func createJSONTests(enc testEncoder) { - keygen := testkeygen.NewKeyGenerator(12, 10) - var err error + keys, err := testkeygen.NewEonKeys(rand.Reader, 12, 10) + if err != nil { + panic(err) + } for i := range testSpecs { testSpec := testSpecs[i] @@ -329,7 +331,7 @@ func createJSONTests(enc testEncoder) { if err != nil { panic(err) } - et, err := createEncryptionTest(keygen, testSpec.payload) + et, err := createEncryptionTest(keys, testSpec.payload) if err != nil { panic(err) } @@ -345,7 +347,7 @@ func createJSONTests(enc testEncoder) { } enc.addTest(&testcase) - dt := createDecryptionTest(keygen, *et) + dt := createDecryptionTest(keys, *et) testcase = testCase{ testCaseMeta: testCaseMeta{ Description: testSpec.description, @@ -360,13 +362,13 @@ func createJSONTests(enc testEncoder) { enc.addTest(&testcase) case tampered: - et, err := createEncryptionTest(keygen, testSpec.payload) + et, err := createEncryptionTest(keys, testSpec.payload) if err != nil { panic(err) } - tamperedEt := tamperEncryptedMessage(keygen, *et) + tamperedEt := tamperEncryptedMessage(keys, *et) - dt := createDecryptionTest(keygen, tamperedEt) + dt := createDecryptionTest(keys, tamperedEt) dt.Expected, _ = hexutil.Decode("0x") testcase := testCase{ testCaseMeta: testCaseMeta{ @@ -383,9 +385,9 @@ func createJSONTests(enc testEncoder) { var err error var vt verificationTest if testSpec.style == verifying { - vt, err = createVerificationTest(keygen, testSpec.payload) + vt, err = createVerificationTest(keys, testSpec.payload) } else { - vt, err = createFailedVerificationTest(keygen, testSpec.payload) + vt, err = createFailedVerificationTest(keys, testSpec.payload) } if err != nil { panic(err) @@ -415,16 +417,19 @@ func verifyTestCase(tc *testCase) error { return tc.Test.Run() } -func createEncryptionTest(keygen *testkeygen.KeyGenerator, message []byte) (*encryptionTest, error) { - epochID := keygen.RandomEpochID(make([]byte, 52)) +func createEncryptionTest(keys *testkeygen.EonKeys, message []byte) (*encryptionTest, error) { + epochID, err := randomEpochID() + if err != nil { + return nil, err + } et := encryptionTest{} et.Message = message - et.EonPublicKey = keygen.EonPublicKey(epochID) + et.EonPublicKey = keys.EonPublicKey() et.EpochID = epochID - sigma, err := keygen.RandomSigma() + sigma, err := shcrypto.RandomSigma(rand.Reader) if err != nil { return &et, err } @@ -434,7 +439,7 @@ func createEncryptionTest(keygen *testkeygen.KeyGenerator, message []byte) (*enc encryptedMessage := shcrypto.Encrypt( et.Message, - keygen.EonPublicKey(epochID), + keys.EonPublicKey(), epochIDPoint, sigma, ) @@ -450,11 +455,13 @@ func createEncryptionTest(keygen *testkeygen.KeyGenerator, message []byte) (*enc } // tamperEncryptedMessage changes the C1 value of EncryptedMessage, which allows to test for malleability issues. -func tamperEncryptedMessage(keygen *testkeygen.KeyGenerator, et encryptionTest) encryptionTest { - decryptionKey := keygen.EpochSecretKey(et.EpochID) +func tamperEncryptedMessage(keys *testkeygen.EonKeys, et encryptionTest) encryptionTest { + decryptionKey, err := keys.EpochSecretKey(et.EpochID) + if err != nil { + panic(err) + } g2 := bls12381.NewG2() var c1 *bls12381.PointG2 - var err error for i := 1; i <= 10000; i++ { c1 = et.Expected.C1 @@ -471,9 +478,12 @@ func tamperEncryptedMessage(keygen *testkeygen.KeyGenerator, et encryptionTest) return et } -func createDecryptionTest(keygen *testkeygen.KeyGenerator, et encryptionTest) decryptionTest { +func createDecryptionTest(keys *testkeygen.EonKeys, et encryptionTest) decryptionTest { dt := decryptionTest{} - epochSecretKey := keygen.EpochSecretKey(et.EpochID) + epochSecretKey, err := keys.EpochSecretKey(et.EpochID) + if err != nil { + panic(err) + } dt.EpochSecretKey = *epochSecretKey dt.Cipher = *et.Expected @@ -483,13 +493,20 @@ func createDecryptionTest(keygen *testkeygen.KeyGenerator, et encryptionTest) de return dt } -func createVerificationTest(keygen *testkeygen.KeyGenerator, payload []byte) (verificationTest, error) { +func createVerificationTest(keys *testkeygen.EonKeys, _ []byte) (verificationTest, error) { var err error vt := verificationTest{} - epochID := keygen.RandomEpochID(payload) + epochID, err := randomEpochID() + if err != nil { + return verificationTest{}, err + } vt.EpochID = epochID - vt.EpochSecretKey = *keygen.EpochSecretKey(epochID) - vt.EonPublicKey = *keygen.EonPublicKey(epochID) + epochSecretKey, err := keys.EpochSecretKey(epochID) + if err != nil { + return verificationTest{}, err + } + vt.EpochSecretKey = *epochSecretKey + vt.EonPublicKey = *keys.EonPublicKey() vt.Expected, err = shcrypto.VerifyEpochSecretKey( &vt.EpochSecretKey, &vt.EonPublicKey, @@ -498,14 +515,25 @@ func createVerificationTest(keygen *testkeygen.KeyGenerator, payload []byte) (ve return vt, err } -func createFailedVerificationTest(keygen *testkeygen.KeyGenerator, _ []byte) (verificationTest, error) { +func createFailedVerificationTest(keys *testkeygen.EonKeys, _ []byte) (verificationTest, error) { var err error vt := verificationTest{} - epochID := keygen.RandomEpochID(make([]byte, 52)) - mismatch := keygen.RandomEpochID(make([]byte, 52)) + epochID, err := randomEpochID() + if err != nil { + return verificationTest{}, err + } vt.EpochID = epochID - vt.EpochSecretKey = *keygen.EpochSecretKey(epochID) - vt.EonPublicKey = *keygen.EonPublicKey(mismatch) + epochSecretKey, err := keys.EpochSecretKey(epochID) + if err != nil { + return verificationTest{}, err + } + vt.EpochSecretKey = *epochSecretKey + + keysMismatch, err := testkeygen.NewEonKeys(rand.Reader, 12, 10) + if err != nil { + return verificationTest{}, err + } + vt.EonPublicKey = *keysMismatch.EonPublicKey() vt.Expected, err = shcrypto.VerifyEpochSecretKey( &vt.EpochSecretKey, &vt.EonPublicKey, @@ -598,3 +626,21 @@ func testMarshalingRoundtrip(tc *testCase) error { } return nil } + +func randomEpochID() (identitypreimage.IdentityPreimage, error) { + epochID := make([]byte, 52) + _, err := rand.Read(epochID) + if err != nil { + return identitypreimage.IdentityPreimage{}, err + } + return identitypreimage.IdentityPreimage(epochID), nil +} + +func randomSigma() shcrypto.Block { + sigma := make([]byte, 32) + _, err := rand.Read(sigma) + if err != nil { + panic(err) + } + return shcrypto.Block(sigma) +} diff --git a/rolling-shutter/collator/eonhandling_test.go b/rolling-shutter/collator/eonhandling_test.go index 8a53d8194..a94dd6073 100644 --- a/rolling-shutter/collator/eonhandling_test.go +++ b/rolling-shutter/collator/eonhandling_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/ecdsa" + "crypto/rand" "testing" "time" @@ -15,7 +16,6 @@ import ( "github.com/shutter-network/rolling-shutter/rolling-shutter/collator/database" "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/configuration" enctime "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/encodeable/time" - "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/identitypreimage" "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/testkeygen" "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/testsetup" "github.com/shutter-network/rolling-shutter/rolling-shutter/p2p" @@ -140,16 +140,14 @@ func TestHandleEonKeyIntegration(t *testing.T) { db := database.New(dbpool) testConfig := newTestConfig(t) - tkgBefore := testkeygen.NewTestKeyGenerator(t, 3, 2, false) - tkg := testkeygen.NewTestKeyGenerator(t, 3, 2, false) - - identityPreimage1 := identitypreimage.Uint64ToIdentityPreimage(1) - identityPreimage1000 := identitypreimage.Uint64ToIdentityPreimage(1000) - identityPreimage2000 := identitypreimage.Uint64ToIdentityPreimage(2000) + keysBefore, err := testkeygen.NewEonKeys(rand.Reader, 3, 2) + assert.NilError(t, err) + keys, err := testkeygen.NewEonKeys(rand.Reader, 3, 2) + assert.NilError(t, err) - eonPubKeyNoThreshold, _ = tkgBefore.EonPublicKey(identityPreimage1).GobEncode() - eonPubKeyBefore, _ = tkgBefore.EonPublicKey(identityPreimage1000).GobEncode() - eonPubKey, _ = tkg.EonPublicKey(identityPreimage2000).GobEncode() + eonPubKeyNoThreshold, _ = keysBefore.EonPublicKey().GobEncode() + eonPubKeyBefore, _ = keysBefore.EonPublicKey().GobEncode() + eonPubKey, _ = keys.EonPublicKey().GobEncode() kpr1, _ := ethcrypto.GenerateKey() kpr2, _ := ethcrypto.GenerateKey() @@ -166,7 +164,7 @@ func TestHandleEonKeyIntegration(t *testing.T) { keyperConfigIndex: uint64(0), activationBlock: activationBlockNoThreshold, eonPubKey: eonPubKeyNoThreshold, - threshold: tkg.Threshold, + threshold: keys.Threshold, keypers: []*ecdsa.PrivateKey{kpr1}, }) assert.Check(t, len(keypersNoThreshold) > 0) @@ -179,7 +177,7 @@ func TestHandleEonKeyIntegration(t *testing.T) { keyperConfigIndex: uint64(1), activationBlock: activationBlockBefore, eonPubKey: eonPubKeyBefore, - threshold: tkg.Threshold, + threshold: keys.Threshold, keypers: []*ecdsa.PrivateKey{kpr1, kpr2, kpr3}, }) assert.Check(t, len(keypersBefore) > 0) @@ -190,7 +188,7 @@ func TestHandleEonKeyIntegration(t *testing.T) { keyperConfigIndex: uint64(2), activationBlock: activationBlock, eonPubKey: eonPubKey, - threshold: tkg.Threshold, + threshold: keys.Threshold, keypers: []*ecdsa.PrivateKey{kpr3, kpr1, kpr2}, }) assert.Check(t, len(keypers) > 0) diff --git a/rolling-shutter/keyper/epochkghandler/benchmark_test.go b/rolling-shutter/keyper/epochkghandler/benchmark_test.go new file mode 100644 index 000000000..fa26c4e2b --- /dev/null +++ b/rolling-shutter/keyper/epochkghandler/benchmark_test.go @@ -0,0 +1,228 @@ +package epochkghandler + +import ( + "context" + "math/big" + "testing" + + "github.com/jackc/pgx/v4/pgxpool" + pubsub "github.com/libp2p/go-libp2p-pubsub" + "github.com/rs/zerolog" + "gotest.tools/assert" + + "github.com/shutter-network/rolling-shutter/rolling-shutter/keyper/database" + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/identitypreimage" + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/testkeygen" + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/testsetup" + "github.com/shutter-network/rolling-shutter/rolling-shutter/p2p" + "github.com/shutter-network/rolling-shutter/rolling-shutter/p2pmsg" +) + +// The number of identity preimages to generate for each of the benchmark runs. Note that this must +// be smaller than MaxNumKeysPerMessage, otherwise the benchmarks will fail. +const numIdentityPreimages = 100 + +func prepareBenchmark(ctx context.Context, b *testing.B, dbpool *pgxpool.Pool) (*testkeygen.EonKeys, []identitypreimage.IdentityPreimage) { + b.Helper() + zerolog.SetGlobalLevel(zerolog.Disabled) + keyperIndex := uint64(1) + identityPreimages := []identitypreimage.IdentityPreimage{} + for i := 0; i < numIdentityPreimages; i++ { + b := make([]byte, 52) + big.NewInt(int64(i)).FillBytes(b) + identityPreimage := identitypreimage.IdentityPreimage(b) + identityPreimages = append(identityPreimages, identityPreimage) + } + + keys := testsetup.InitializeEon(ctx, b, dbpool, config, keyperIndex) + return keys, identityPreimages +} + +func prepareKeysBenchmark(ctx context.Context, b *testing.B, dbpool *pgxpool.Pool) (p2p.MessageHandler, *p2pmsg.DecryptionKeys) { + b.Helper() + zerolog.SetGlobalLevel(zerolog.Disabled) + keys, identityPreimages := prepareBenchmark(ctx, b, dbpool) + + encodedDecryptionKeys := [][]byte{} + for _, identityPreimage := range identityPreimages { + decryptionKey, err := keys.EpochSecretKey(identityPreimage) + assert.NilError(b, err) + encodedDecryptionKey := decryptionKey.Marshal() + encodedDecryptionKeys = append(encodedDecryptionKeys, encodedDecryptionKey) + } + decryptionKeys := []*p2pmsg.Key{} + for i, identityPreimage := range identityPreimages { + key := &p2pmsg.Key{ + Identity: identityPreimage.Bytes(), + Key: encodedDecryptionKeys[i], + } + decryptionKeys = append(decryptionKeys, key) + } + msg := &p2pmsg.DecryptionKeys{ + InstanceID: config.GetInstanceID(), + Eon: 1, + Keys: decryptionKeys, + } + + var handler p2p.MessageHandler = &DecryptionKeyHandler{config: config, dbpool: dbpool} + + return handler, msg +} + +func prepareKeySharesBenchmark( + ctx context.Context, + b *testing.B, + dbpool *pgxpool.Pool, + isSecond bool, +) (p2p.MessageHandler, *p2pmsg.DecryptionKeyShares) { + b.Helper() + zerolog.SetGlobalLevel(zerolog.Disabled) + keys, identityPreimages := prepareBenchmark(ctx, b, dbpool) + var handler p2p.MessageHandler = &DecryptionKeyShareHandler{config: config, dbpool: dbpool} + + if isSecond { + shares := []*p2pmsg.KeyShare{} + keyperIndex := 0 + for _, identityPreimage := range identityPreimages { + share := &p2pmsg.KeyShare{ + EpochID: identityPreimage.Bytes(), + Share: keys.EpochSecretKeyShare(identityPreimage, keyperIndex).Marshal(), + } + shares = append(shares, share) + } + msg := &p2pmsg.DecryptionKeyShares{ + InstanceID: config.GetInstanceID(), + Eon: 1, + KeyperIndex: uint64(keyperIndex), + Shares: shares, + } + validationResult, err := handler.ValidateMessage(ctx, msg) + assert.NilError(b, err) + assert.Check(b, validationResult == pubsub.ValidationAccept) + _, err = handler.HandleMessage(ctx, msg) + assert.NilError(b, err) + } + + keyperIndex := 2 + shares := []*p2pmsg.KeyShare{} + for _, identityPreimage := range identityPreimages { + share := &p2pmsg.KeyShare{ + EpochID: identityPreimage.Bytes(), + Share: keys.EpochSecretKeyShare(identityPreimage, keyperIndex).Marshal(), + } + shares = append(shares, share) + } + msg := &p2pmsg.DecryptionKeyShares{ + InstanceID: config.GetInstanceID(), + Eon: 1, + KeyperIndex: uint64(keyperIndex), + Shares: shares, + } + + return handler, msg +} + +func BenchmarkValidateKeysIntegration(b *testing.B) { + ctx := context.Background() + dbpool, dbclose := testsetup.NewTestDBPool(ctx, b, database.Definition) + b.Cleanup(dbclose) + handler, msg := prepareKeysBenchmark(ctx, b, dbpool) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StartTimer() + validationResult, err := handler.ValidateMessage(ctx, msg) + b.StopTimer() + assert.NilError(b, err) + assert.Check(b, validationResult == pubsub.ValidationAccept) + } +} + +func BenchmarkHandleKeysIntegration(b *testing.B) { + ctx := context.Background() + dbpool, dbclose := testsetup.NewTestDBPool(ctx, b, database.Definition) + b.Cleanup(dbclose) + handler, msg := prepareKeysBenchmark(ctx, b, dbpool) + + validationResult, err := handler.ValidateMessage(ctx, msg) + assert.NilError(b, err) + assert.Check(b, validationResult == pubsub.ValidationAccept) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StartTimer() + _, err = handler.HandleMessage(ctx, msg) + b.StopTimer() + assert.NilError(b, err) + } +} + +func BenchmarkValidateFirstKeySharesIntegration(b *testing.B) { + ctx := context.Background() + dbpool, dbclose := testsetup.NewTestDBPool(ctx, b, database.Definition) + b.Cleanup(dbclose) + handler, msg := prepareKeySharesBenchmark(ctx, b, dbpool, false) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StartTimer() + validationResult, err := handler.ValidateMessage(ctx, msg) + b.StopTimer() + assert.NilError(b, err) + assert.Check(b, validationResult == pubsub.ValidationAccept) + } +} + +func BenchmarkHandleFirstKeySharesIntegration(b *testing.B) { + ctx := context.Background() + dbpool, dbclose := testsetup.NewTestDBPool(ctx, b, database.Definition) + b.Cleanup(dbclose) + handler, msg := prepareKeySharesBenchmark(ctx, b, dbpool, false) + + validationResult, err := handler.ValidateMessage(ctx, msg) + assert.NilError(b, err) + assert.Check(b, validationResult == pubsub.ValidationAccept) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StartTimer() + _, err = handler.HandleMessage(ctx, msg) + b.StopTimer() + assert.NilError(b, err) + } +} + +func BenchmarkValidateSecondKeySharesIntegration(b *testing.B) { + ctx := context.Background() + dbpool, dbclose := testsetup.NewTestDBPool(ctx, b, database.Definition) + b.Cleanup(dbclose) + handler, msg := prepareKeySharesBenchmark(ctx, b, dbpool, true) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StartTimer() + validationResult, err := handler.ValidateMessage(ctx, msg) + b.StopTimer() + assert.NilError(b, err) + assert.Check(b, validationResult == pubsub.ValidationAccept) + } +} + +func BenchmarkHandleSecondKeySharesIntegration(b *testing.B) { + ctx := context.Background() + dbpool, dbclose := testsetup.NewTestDBPool(ctx, b, database.Definition) + b.Cleanup(dbclose) + handler, msg := prepareKeySharesBenchmark(ctx, b, dbpool, true) + + validationResult, err := handler.ValidateMessage(ctx, msg) + assert.NilError(b, err) + assert.Check(b, validationResult == pubsub.ValidationAccept) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StartTimer() + _, err = handler.HandleMessage(ctx, msg) + b.StopTimer() + assert.NilError(b, err) + } +} diff --git a/rolling-shutter/keyper/epochkghandler/key_test.go b/rolling-shutter/keyper/epochkghandler/key_test.go index 12ca83b92..16e780bf0 100644 --- a/rolling-shutter/keyper/epochkghandler/key_test.go +++ b/rolling-shutter/keyper/epochkghandler/key_test.go @@ -28,41 +28,43 @@ func TestHandleDecryptionKeyIntegration(t *testing.T) { queries := database.New(dbpool) - eon := config.GetEon() identityPreimages := []identitypreimage.IdentityPreimage{} for i := 0; i < 3; i++ { identityPreimage := identitypreimage.Uint64ToIdentityPreimage(uint64(i)) identityPreimages = append(identityPreimages, identityPreimage) } keyperIndex := uint64(1) + keyperConfigIndex := uint64(1) - tkg := testsetup.InitializeEon(ctx, t, dbpool, config, keyperIndex) + keys := testsetup.InitializeEon(ctx, t, dbpool, config, keyperIndex) var handler p2p.MessageHandler = &DecryptionKeyHandler{config: config, dbpool: dbpool} encodedDecryptionKeys := [][]byte{} for _, identityPreimage := range identityPreimages { - encodedDecryptionKey := tkg.EpochSecretKey(identityPreimage).Marshal() + decryptionKey, err := keys.EpochSecretKey(identityPreimage) + assert.NilError(t, err) + encodedDecryptionKey := decryptionKey.Marshal() encodedDecryptionKeys = append(encodedDecryptionKeys, encodedDecryptionKey) } // send a decryption key and check that it gets inserted - keys := []*p2pmsg.Key{} + decryptionKeys := []*p2pmsg.Key{} for i, identityPreimage := range identityPreimages { key := &p2pmsg.Key{ Identity: identityPreimage.Bytes(), Key: encodedDecryptionKeys[i], } - keys = append(keys, key) + decryptionKeys = append(decryptionKeys, key) } msgs := p2ptest.MustHandleMessage(t, handler, ctx, &p2pmsg.DecryptionKeys{ InstanceID: config.GetInstanceID(), - Eon: eon, - Keys: keys, + Eon: keyperConfigIndex, + Keys: decryptionKeys, }) assert.Check(t, len(msgs) == 0) for i, identityPreimage := range identityPreimages { key, err := queries.GetDecryptionKey(ctx, database.GetDecryptionKeyParams{ - Eon: int64(eon), + Eon: int64(keyperConfigIndex), EpochID: identityPreimage.Bytes(), }) assert.NilError(t, err) @@ -80,13 +82,15 @@ func TestDecryptionKeyValidatorIntegration(t *testing.T) { t.Cleanup(dbclose) keyperIndex := uint64(1) - eon := config.GetEon() + keyperConfigIndex := uint64(1) identityPreimage := identitypreimage.BigToIdentityPreimage(common.Big0) secondIdentityPreimage := identitypreimage.BigToIdentityPreimage(common.Big1) wrongIdentityPreimage := identitypreimage.BigToIdentityPreimage(common.Big2) - tkg := testsetup.InitializeEon(ctx, t, dbpool, config, keyperIndex) - secretKey := tkg.EpochSecretKey(identityPreimage).Marshal() - secondSecretKey := tkg.EpochSecretKey(secondIdentityPreimage).Marshal() + keys := testsetup.InitializeEon(ctx, t, dbpool, config, keyperIndex) + secretKey, err := keys.EpochSecretKey(identityPreimage) + assert.NilError(t, err) + secondSecretKey, err := keys.EpochSecretKey(secondIdentityPreimage) + assert.NilError(t, err) var handler p2p.MessageHandler = &DecryptionKeyHandler{config: config, dbpool: dbpool} tests := []struct { @@ -99,11 +103,11 @@ func TestDecryptionKeyValidatorIntegration(t *testing.T) { validationResult: pubsub.ValidationAccept, msg: &p2pmsg.DecryptionKeys{ InstanceID: config.GetInstanceID(), - Eon: eon, + Eon: keyperConfigIndex, Keys: []*p2pmsg.Key{ { Identity: identityPreimage.Bytes(), - Key: secretKey, + Key: secretKey.Marshal(), }, }, }, @@ -113,11 +117,11 @@ func TestDecryptionKeyValidatorIntegration(t *testing.T) { validationResult: pubsub.ValidationReject, msg: &p2pmsg.DecryptionKeys{ InstanceID: config.GetInstanceID(), - Eon: eon, + Eon: keyperConfigIndex, Keys: []*p2pmsg.Key{ { Identity: wrongIdentityPreimage.Bytes(), - Key: secretKey, + Key: secretKey.Marshal(), }, }, }, @@ -127,11 +131,11 @@ func TestDecryptionKeyValidatorIntegration(t *testing.T) { validationResult: pubsub.ValidationReject, msg: &p2pmsg.DecryptionKeys{ InstanceID: config.GetInstanceID() + 1, - Eon: eon, + Eon: keyperConfigIndex, Keys: []*p2pmsg.Key{ { Identity: identityPreimage.Bytes(), - Key: secretKey, + Key: secretKey.Marshal(), }, }, }, @@ -141,42 +145,24 @@ func TestDecryptionKeyValidatorIntegration(t *testing.T) { validationResult: pubsub.ValidationReject, msg: &p2pmsg.DecryptionKeys{ InstanceID: config.GetInstanceID(), - Eon: eon, + Eon: keyperConfigIndex, Keys: []*p2pmsg.Key{}, }, }, - { - name: "invalid decryption key duplicate", - validationResult: pubsub.ValidationReject, - msg: &p2pmsg.DecryptionKeys{ - InstanceID: config.GetInstanceID(), - Eon: eon, - Keys: []*p2pmsg.Key{ - { - Identity: identityPreimage.Bytes(), - Key: secretKey, - }, - { - Identity: identityPreimage.Bytes(), - Key: secretKey, - }, - }, - }, - }, { name: "invalid decryption key unordered", validationResult: pubsub.ValidationReject, msg: &p2pmsg.DecryptionKeys{ InstanceID: config.GetInstanceID(), - Eon: eon, + Eon: keyperConfigIndex, Keys: []*p2pmsg.Key{ { Identity: secondIdentityPreimage.Bytes(), - Key: secondSecretKey, + Key: secondSecretKey.Marshal(), }, { Identity: identityPreimage.Bytes(), - Key: secretKey, + Key: secretKey.Marshal(), }, }, }, diff --git a/rolling-shutter/keyper/epochkghandler/keyshare_test.go b/rolling-shutter/keyper/epochkghandler/keyshare_test.go index 8f4a85c90..399289f4f 100644 --- a/rolling-shutter/keyper/epochkghandler/keyshare_test.go +++ b/rolling-shutter/keyper/epochkghandler/keyshare_test.go @@ -32,13 +32,15 @@ func TestHandleDecryptionKeyShareIntegration(t *testing.T) { identityPreimages = append(identityPreimages, identityPreimage) } keyperIndex := uint64(1) + keyperConfigIndex := uint64(1) - tkg := testsetup.InitializeEon(ctx, t, dbpool, config, keyperIndex) + keys := testsetup.InitializeEon(ctx, t, dbpool, config, keyperIndex) var handler p2p.MessageHandler = &DecryptionKeyShareHandler{config: config, dbpool: dbpool} encodedDecryptionKeys := [][]byte{} for _, identityPreimage := range identityPreimages { - encodedDecryptionKey := tkg.EpochSecretKey(identityPreimage).Marshal() - encodedDecryptionKeys = append(encodedDecryptionKeys, encodedDecryptionKey) + encodedDecryptionKey, err := keys.EpochSecretKey(identityPreimage) + assert.NilError(t, err) + encodedDecryptionKeys = append(encodedDecryptionKeys, encodedDecryptionKey.Marshal()) } // threshold is two, so no outgoing message after first input @@ -46,13 +48,13 @@ func TestHandleDecryptionKeyShareIntegration(t *testing.T) { for _, identityPreimage := range identityPreimages { share := &p2pmsg.KeyShare{ EpochID: identityPreimage.Bytes(), - Share: tkg.EpochSecretKeyShare(identityPreimage, 0).Marshal(), + Share: keys.EpochSecretKeyShare(identityPreimage, 0).Marshal(), } shares = append(shares, share) } msgs := p2ptest.MustHandleMessage(t, handler, ctx, &p2pmsg.DecryptionKeyShares{ InstanceID: config.GetInstanceID(), - Eon: config.GetEon(), + Eon: keyperConfigIndex, KeyperIndex: 0, Shares: shares, }) @@ -64,13 +66,13 @@ func TestHandleDecryptionKeyShareIntegration(t *testing.T) { for _, identityPreimage := range identityPreimages { share := &p2pmsg.KeyShare{ EpochID: identityPreimage.Bytes(), - Share: tkg.EpochSecretKeyShare(identityPreimage, 2).Marshal(), + Share: keys.EpochSecretKeyShare(identityPreimage, 2).Marshal(), } shares = append(shares, share) } msgs = p2ptest.MustHandleMessage(t, handler, ctx, &p2pmsg.DecryptionKeyShares{ InstanceID: config.GetInstanceID(), - Eon: config.GetEon(), + Eon: keyperConfigIndex, KeyperIndex: 2, Shares: shares, }) @@ -96,13 +98,13 @@ func TestDecryptionKeyshareValidatorIntegration(t *testing.T) { t.Cleanup(dbclose) keyperIndex := uint64(1) - eon := config.GetEon() + keyperConfigIndex := uint64(1) identityPreimage := identitypreimage.BigToIdentityPreimage(common.Big0) secondIdentityPreimage := identitypreimage.BigToIdentityPreimage(common.Big1) wrongIdentityPreimage := identitypreimage.BigToIdentityPreimage(common.Big2) - tkg := testsetup.InitializeEon(ctx, t, dbpool, config, keyperIndex) - keyshare := tkg.EpochSecretKeyShare(identityPreimage, keyperIndex).Marshal() - secondKeyshare := tkg.EpochSecretKeyShare(secondIdentityPreimage, keyperIndex).Marshal() + keys := testsetup.InitializeEon(ctx, t, dbpool, config, keyperIndex) + keyshare := keys.EpochSecretKeyShare(identityPreimage, int(keyperIndex)).Marshal() + secondKeyshare := keys.EpochSecretKeyShare(secondIdentityPreimage, int(keyperIndex)).Marshal() var handler p2p.MessageHandler = &DecryptionKeyShareHandler{config: config, dbpool: dbpool} tests := []struct { @@ -115,7 +117,7 @@ func TestDecryptionKeyshareValidatorIntegration(t *testing.T) { validationResult: pubsub.ValidationAccept, msg: &p2pmsg.DecryptionKeyShares{ InstanceID: config.GetInstanceID(), - Eon: eon, + Eon: keyperConfigIndex, KeyperIndex: keyperIndex, Shares: []*p2pmsg.KeyShare{ { @@ -134,7 +136,7 @@ func TestDecryptionKeyshareValidatorIntegration(t *testing.T) { validationResult: pubsub.ValidationReject, msg: &p2pmsg.DecryptionKeyShares{ InstanceID: config.GetInstanceID(), - Eon: eon, + Eon: keyperConfigIndex, KeyperIndex: keyperIndex, Shares: []*p2pmsg.KeyShare{ { @@ -149,7 +151,7 @@ func TestDecryptionKeyshareValidatorIntegration(t *testing.T) { validationResult: pubsub.ValidationReject, msg: &p2pmsg.DecryptionKeyShares{ InstanceID: config.GetInstanceID() + 1, - Eon: eon, + Eon: keyperConfigIndex, KeyperIndex: keyperIndex, Shares: []*p2pmsg.KeyShare{ { @@ -164,7 +166,7 @@ func TestDecryptionKeyshareValidatorIntegration(t *testing.T) { validationResult: pubsub.ValidationReject, msg: &p2pmsg.DecryptionKeyShares{ InstanceID: config.GetInstanceID(), - Eon: eon, + Eon: keyperConfigIndex, KeyperIndex: keyperIndex + 1, Shares: []*p2pmsg.KeyShare{ { @@ -179,36 +181,17 @@ func TestDecryptionKeyshareValidatorIntegration(t *testing.T) { validationResult: pubsub.ValidationReject, msg: &p2pmsg.DecryptionKeyShares{ InstanceID: config.GetInstanceID(), - Eon: eon, + Eon: keyperConfigIndex, KeyperIndex: keyperIndex, Shares: []*p2pmsg.KeyShare{}, }, }, - { - name: "invalid decryption key share duplicate", - validationResult: pubsub.ValidationReject, - msg: &p2pmsg.DecryptionKeyShares{ - InstanceID: config.GetInstanceID(), - Eon: eon, - KeyperIndex: keyperIndex, - Shares: []*p2pmsg.KeyShare{ - { - EpochID: identityPreimage.Bytes(), - Share: keyshare, - }, - { - EpochID: identityPreimage.Bytes(), - Share: keyshare, - }, - }, - }, - }, { name: "invalid decryption key share unordered", validationResult: pubsub.ValidationReject, msg: &p2pmsg.DecryptionKeyShares{ InstanceID: config.GetInstanceID(), - Eon: eon, + Eon: keyperConfigIndex, KeyperIndex: keyperIndex, Shares: []*p2pmsg.KeyShare{ { diff --git a/rolling-shutter/keyper/epochkghandler/trigger_test.go b/rolling-shutter/keyper/epochkghandler/trigger_test.go index a07dc83eb..736482975 100644 --- a/rolling-shutter/keyper/epochkghandler/trigger_test.go +++ b/rolling-shutter/keyper/epochkghandler/trigger_test.go @@ -29,6 +29,7 @@ func TestHandleDecryptionTriggerIntegration(t *testing.T) { identityPreimage := identitypreimage.Uint64ToIdentityPreimage(50) keyperIndex := uint64(1) + keyperConfigIndex := int64(1) testsetup.InitializeEon(ctx, t, dbpool, config, keyperIndex) @@ -55,7 +56,6 @@ func TestHandleDecryptionTriggerIntegration(t *testing.T) { IdentityPreimages: []identitypreimage.IdentityPreimage{identityPreimage}, } decrTrigChan <- broker.NewEvent(trig) - decrTrigChan <- broker.NewEvent(trig) close(decrTrigChan) err = group.Wait() cleanup() @@ -63,14 +63,11 @@ func TestHandleDecryptionTriggerIntegration(t *testing.T) { // send decryption key share when first trigger is received share, err := db.GetDecryptionKeyShare(ctx, database.GetDecryptionKeyShareParams{ - Eon: int64(config.GetEon()), + Eon: keyperConfigIndex, EpochID: identityPreimage.Bytes(), KeyperIndex: int64(keyperIndex), }) assert.NilError(t, err) - // although we requested the trigger 2 times, the keyshare should have been - // sent only once - assert.Check(t, len(messaging.SentMessages) == 1) msg, ok := messaging.SentMessages[0].Message.(*p2pmsg.DecryptionKeyShares) assert.Check(t, ok) diff --git a/rolling-shutter/medley/testkeygen/eonkeys.go b/rolling-shutter/medley/testkeygen/eonkeys.go index 0d7e87cf0..b5033f914 100644 --- a/rolling-shutter/medley/testkeygen/eonkeys.go +++ b/rolling-shutter/medley/testkeygen/eonkeys.go @@ -11,8 +11,8 @@ import ( // KeyperKeyShares holds the public and private key shares of a single keyper. type KeyperKeyShares struct { - eonPublicKeyShare *shcrypto.EonPublicKeyShare - eonSecretKeyShare *shcrypto.EonSecretKeyShare + EonPublicKeyShare *shcrypto.EonPublicKeyShare + EonSecretKeyShare *shcrypto.EonSecretKeyShare } // ComputeEpochSecretKeyShare computes the secret key share for the given epoch. @@ -20,7 +20,7 @@ func (kks *KeyperKeyShares) ComputeEpochSecretKeyShare( identityPreimage identitypreimage.IdentityPreimage, ) *shcrypto.EpochSecretKeyShare { epochIDG1 := shcrypto.ComputeEpochID(identityPreimage.Bytes()) - return shcrypto.ComputeEpochSecretKeyShare(kks.eonSecretKeyShare, epochIDG1) + return shcrypto.ComputeEpochSecretKeyShare(kks.EonSecretKeyShare, epochIDG1) } // EonKeys holds all keys for one eon. @@ -54,8 +54,8 @@ func NewEonKeys(random io.Reader, numKeypers uint64, threshold uint64) (*EonKeys vs = append(vs, v) } shares = append(shares, KeyperKeyShares{ - eonSecretKeyShare: shcrypto.ComputeEonSecretKeyShare(vs), - eonPublicKeyShare: shcrypto.ComputeEonPublicKeyShare(i, gammas), + EonSecretKeyShare: shcrypto.ComputeEonSecretKeyShare(vs), + EonPublicKeyShare: shcrypto.ComputeEonPublicKeyShare(i, gammas), }) } @@ -67,6 +67,18 @@ func NewEonKeys(random io.Reader, numKeypers uint64, threshold uint64) (*EonKeys }, nil } +func (eonkeys *EonKeys) EonPublicKey() *shcrypto.EonPublicKey { + return eonkeys.publicKey +} + +func (eonkeys *EonKeys) EonPublicKeyShare(keyperIndex int) *shcrypto.EonPublicKeyShare { + return eonkeys.keyperShares[keyperIndex].EonPublicKeyShare +} + +func (eonkeys *EonKeys) EonSecretKeyShare(keyperIndex int) *shcrypto.EonSecretKeyShare { + return eonkeys.keyperShares[keyperIndex].EonSecretKeyShare +} + func (eonkeys *EonKeys) getEpochSecretKeyShares( identityPreimage identitypreimage.IdentityPreimage, keyperIndices []int, @@ -78,12 +90,21 @@ func (eonkeys *EonKeys) getEpochSecretKeyShares( return res } +func (eonkeys *EonKeys) EpochSecretKeyShare( + identityPreimage identitypreimage.IdentityPreimage, + keyperIndex int, +) *shcrypto.EpochSecretKeyShare { + return eonkeys.keyperShares[keyperIndex].ComputeEpochSecretKeyShare(identityPreimage) +} + func (eonkeys *EonKeys) EpochSecretKey(identityPreimage identitypreimage.IdentityPreimage) (*shcrypto.EpochSecretKey, error) { keyperIndices := []int{} + epochSecretKeyShares := []*shcrypto.EpochSecretKeyShare{} for i := uint64(0); i < eonkeys.Threshold; i++ { keyperIndices = append(keyperIndices, int(i)) + epochSecretKeyShare := eonkeys.EpochSecretKeyShare(identityPreimage, int(i)) + epochSecretKeyShares = append(epochSecretKeyShares, epochSecretKeyShare) } - epochSecretKeyShares := eonkeys.getEpochSecretKeyShares(identityPreimage, keyperIndices) return shcrypto.ComputeEpochSecretKey( keyperIndices, epochSecretKeyShares, diff --git a/rolling-shutter/medley/testkeygen/keygenerator.go b/rolling-shutter/medley/testkeygen/keygenerator.go deleted file mode 100644 index c522b70b5..000000000 --- a/rolling-shutter/medley/testkeygen/keygenerator.go +++ /dev/null @@ -1,97 +0,0 @@ -package testkeygen - -import ( - "math/rand" - "testing" - - "gotest.tools/assert" - - "github.com/shutter-network/shutter/shlib/shcrypto" - - "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/identitypreimage" -) - -// TestKeyGenerator is a helper tool to generate secret and public eon and epoch keys and key -// shares. It will generate a new eon key every eonInterval epochs. -type TestKeyGenerator struct { - t *testing.T - eonInterval uint64 - eonKeyGen map[uint64]*EonKeys - NumKeypers uint64 - Threshold uint64 -} - -func NewTestKeyGenerator(t *testing.T, numKeypers uint64, threshold uint64, infiniteInterval bool) *TestKeyGenerator { - t.Helper() - eonInterval := 100 - if infiniteInterval { - eonInterval = 0 // 0 stands for infinity - } - return &TestKeyGenerator{ - t: t, - eonInterval: uint64(eonInterval), - eonKeyGen: make(map[uint64]*EonKeys), - NumKeypers: numKeypers, - Threshold: threshold, - } -} - -// getEonIndex computes the index of the EON key to be used for the given identityPreimage. We generate a new -// eon key every eonInterval epochs. -func (tkg *TestKeyGenerator) getEonIndex(identityPreimage identitypreimage.IdentityPreimage) uint64 { - if tkg.eonInterval == 0 { - return 0 - } - - return identityPreimage.Big().Uint64() / tkg.eonInterval -} - -func (tkg *TestKeyGenerator) EonKeysForEpoch(identityPreimage identitypreimage.IdentityPreimage) *EonKeys { - tkg.t.Helper() - var err error - eonIndex := tkg.getEonIndex(identityPreimage) - res, ok := tkg.eonKeyGen[eonIndex] - if !ok { - res, err = NewEonKeys( - rand.New(rand.NewSource(int64(eonIndex))), //nolint:gosec - tkg.NumKeypers, - tkg.Threshold, - ) - assert.NilError(tkg.t, err) - tkg.eonKeyGen[eonIndex] = res - } - return res -} - -func (tkg *TestKeyGenerator) EonPublicKeyShare(identityPreimage identitypreimage.IdentityPreimage, - keyperIndex uint64, -) *shcrypto.EonPublicKeyShare { - tkg.t.Helper() - return tkg.EonKeysForEpoch(identityPreimage).keyperShares[keyperIndex].eonPublicKeyShare -} - -func (tkg *TestKeyGenerator) EonPublicKey(identityPreimage identitypreimage.IdentityPreimage) *shcrypto.EonPublicKey { - tkg.t.Helper() - return tkg.EonKeysForEpoch(identityPreimage).publicKey -} - -func (tkg *TestKeyGenerator) EonSecretKeyShare(identityPreimage identitypreimage.IdentityPreimage, - keyperIndex uint64, -) *shcrypto.EonSecretKeyShare { - tkg.t.Helper() - return tkg.EonKeysForEpoch(identityPreimage).keyperShares[keyperIndex].eonSecretKeyShare -} - -func (tkg *TestKeyGenerator) EpochSecretKeyShare(identityPreimage identitypreimage.IdentityPreimage, - keyperIndex uint64, -) *shcrypto.EpochSecretKeyShare { - tkg.t.Helper() - return tkg.EonKeysForEpoch(identityPreimage).keyperShares[keyperIndex].ComputeEpochSecretKeyShare(identityPreimage) -} - -func (tkg *TestKeyGenerator) EpochSecretKey(identityPreimage identitypreimage.IdentityPreimage) *shcrypto.EpochSecretKey { - tkg.t.Helper() - epochSecretKey, err := tkg.EonKeysForEpoch(identityPreimage).EpochSecretKey(identityPreimage) - assert.NilError(tkg.t, err) - return epochSecretKey -} diff --git a/rolling-shutter/medley/testkeygen/testgenerator.go b/rolling-shutter/medley/testkeygen/testgenerator.go deleted file mode 100644 index 4e127386c..000000000 --- a/rolling-shutter/medley/testkeygen/testgenerator.go +++ /dev/null @@ -1,93 +0,0 @@ -package testkeygen - -import ( - "crypto/rand" - - "github.com/shutter-network/shutter/shlib/shcrypto" - - "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/identitypreimage" -) - -// KeyGenerator is a helper tool to generate secret and public eon and epoch keys and key -// shares. It will generate a new eon key every eonInterval epochs. -type KeyGenerator struct { - eonInterval uint64 - eonKeyGen map[uint64]*EonKeys - NumKeypers uint64 - Threshold uint64 -} - -func NewKeyGenerator(numKeypers uint64, threshold uint64) *KeyGenerator { - return &KeyGenerator{ - eonInterval: 100, // 0 stands for infinity - eonKeyGen: make(map[uint64]*EonKeys), - NumKeypers: numKeypers, - Threshold: threshold, - } -} - -// getEonIndex computes the index of the EON key to be used for the given epochID. We generate a new -// eon key every eonInterval epochs. -func (kg *KeyGenerator) getEonIndex(epochID identitypreimage.IdentityPreimage) uint64 { - if kg.eonInterval == 0 { - return 0 - } - - return epochID.Big().Uint64() / kg.eonInterval -} - -func (kg *KeyGenerator) EonKeysForEpoch(epochID identitypreimage.IdentityPreimage) *EonKeys { - eonIndex := kg.getEonIndex(epochID) - res, ok := kg.eonKeyGen[eonIndex] - var err error - if !ok { - res, err = NewEonKeys( - rand.Reader, - kg.NumKeypers, - kg.Threshold, - ) - if err != nil { - return nil - } - kg.eonKeyGen[eonIndex] = res - } - return res -} - -func (kg *KeyGenerator) EonPublicKeyShare(epochID identitypreimage.IdentityPreimage, keyperIndex uint64) *shcrypto.EonPublicKeyShare { - return kg.EonKeysForEpoch(epochID).keyperShares[keyperIndex].eonPublicKeyShare -} - -func (kg *KeyGenerator) EonPublicKey(epochID identitypreimage.IdentityPreimage) *shcrypto.EonPublicKey { - return kg.EonKeysForEpoch(epochID).publicKey -} - -func (kg *KeyGenerator) EonSecretKeyShare(epochID identitypreimage.IdentityPreimage, keyperIndex uint64) *shcrypto.EonSecretKeyShare { - return kg.EonKeysForEpoch(epochID).keyperShares[keyperIndex].eonSecretKeyShare -} - -func (kg *KeyGenerator) EpochSecretKeyShare(epochID identitypreimage.IdentityPreimage, keyperIndex uint64) *shcrypto.EpochSecretKeyShare { - return kg.EonKeysForEpoch(epochID).keyperShares[keyperIndex].ComputeEpochSecretKeyShare(epochID) -} - -func (kg *KeyGenerator) EpochSecretKey(epochID identitypreimage.IdentityPreimage) *shcrypto.EpochSecretKey { - epochSecretKey, err := kg.EonKeysForEpoch(epochID).EpochSecretKey(epochID) - if err != nil { - panic(err) - } - return epochSecretKey -} - -func (kg *KeyGenerator) RandomEpochID(epochbytes []byte) identitypreimage.IdentityPreimage { - _, err := rand.Read(epochbytes) - if err != nil { - panic(err) - } - - epochID := identitypreimage.IdentityPreimage(epochbytes) - return epochID -} - -func (kg *KeyGenerator) RandomSigma() (shcrypto.Block, error) { - return shcrypto.RandomSigma(rand.Reader) -} diff --git a/rolling-shutter/medley/testsetup/database.go b/rolling-shutter/medley/testsetup/database.go index 9185007a2..60d0f0056 100644 --- a/rolling-shutter/medley/testsetup/database.go +++ b/rolling-shutter/medley/testsetup/database.go @@ -42,24 +42,24 @@ var testDBSuffix = "-test" // newDBPoolTeardown connects to a test db specified an environment variable and clears it from all // schemas we might have created. It returns the db connection pool and a close function. Call the // close function at the end of the test to reset the db again and close the connection. -func newDBPoolTeardown(ctx context.Context, t *testing.T) (*pgxpool.Pool, func()) { - t.Helper() +func newDBPoolTeardown(ctx context.Context, tb testing.TB) (*pgxpool.Pool, func()) { + tb.Helper() testDBURL, exists := os.LookupEnv(testDBURLVar) if !exists { - t.Skipf("no test db specified, please set %s", testDBURLVar) + tb.Fatalf("no test db specified, please set %s", testDBURLVar) } dbpool, err := pgxpool.Connect(ctx, testDBURL) if err != nil { - t.Fatalf("failed to connect to test db: %v", err) + tb.Fatalf("failed to connect to test db: %v", err) } closedb := func() { _, err = dbpool.Exec(ctx, dropEverything) dbpool.Close() // close db no matter if dropping failed if err != nil { - t.Fatalf("failed to reset test db: %v", err) + tb.Fatalf("failed to reset test db: %v", err) } } @@ -67,22 +67,22 @@ func newDBPoolTeardown(ctx context.Context, t *testing.T) (*pgxpool.Pool, func() _, err = dbpool.Exec(ctx, dropEverything) if err != nil { dbpool.Close() - t.Fatalf("failed to reset test db: %v", err) + tb.Fatalf("failed to reset test db: %v", err) } return dbpool, closedb } -func NewTestDBPool(ctx context.Context, t *testing.T, definition db.Definition) (*pgxpool.Pool, func()) { - t.Helper() +func NewTestDBPool(ctx context.Context, tb testing.TB, definition db.Definition) (*pgxpool.Pool, func()) { + tb.Helper() - dbpool, closedb := newDBPoolTeardown(ctx, t) + dbpool, closedb := newDBPoolTeardown(ctx, tb) err := db.InitDB(ctx, dbpool, definition.Name()+testDBSuffix, definition) if err != nil { log.Error().Err(err).Str("db-definition", definition.Name()).Msg("Initializing DB failed") closedb() - t.Fatalf("failed to initialize '%s' db", definition.Name()) + tb.Fatalf("failed to initialize '%s' db", definition.Name()) } return dbpool, closedb } diff --git a/rolling-shutter/medley/testsetup/eon.go b/rolling-shutter/medley/testsetup/eon.go index 94b5e04a7..951b4d32d 100644 --- a/rolling-shutter/medley/testsetup/eon.go +++ b/rolling-shutter/medley/testsetup/eon.go @@ -3,6 +3,7 @@ package testsetup import ( "context" "crypto/ecdsa" + "crypto/rand" "database/sql" "testing" @@ -17,7 +18,6 @@ import ( chainobsdb "github.com/shutter-network/rolling-shutter/rolling-shutter/chainobserver/db/collator" "github.com/shutter-network/rolling-shutter/rolling-shutter/keyper/database" "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/db" - "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/identitypreimage" "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/testkeygen" "github.com/shutter-network/rolling-shutter/rolling-shutter/shdb" ) @@ -31,15 +31,15 @@ type TestConfig interface { func InitializeEon( ctx context.Context, - t *testing.T, + tb testing.TB, dbpool *pgxpool.Pool, config TestConfig, keyperIndex uint64, -) *testkeygen.TestKeyGenerator { - t.Helper() +) *testkeygen.EonKeys { + tb.Helper() err := dbpool.BeginFunc(db.WrapContext(ctx, database.Definition.Validate)) - assert.NilError(t, err) + assert.NilError(tb, err) keyperDB := database.New(dbpool) keypers := []string{ @@ -51,55 +51,56 @@ func InitializeEon( collatorKey := config.GetCollatorKey() if collatorKey != nil { err := dbpool.BeginFunc(db.WrapContext(ctx, chainobsdb.Definition.Validate)) - assert.NilError(t, err) + assert.NilError(tb, err) chdb := chainobsdb.New(dbpool) err = chdb.InsertChainCollator(ctx, chainobsdb.InsertChainCollatorParams{ ActivationBlockNumber: 0, Collator: shdb.EncodeAddress(ethcrypto.PubkeyToAddress(config.GetCollatorKey().PublicKey)), }) - assert.NilError(t, err) + assert.NilError(tb, err) } - tkg := testkeygen.NewTestKeyGenerator(t, 3, 2, false) + eonKeys, err := testkeygen.NewEonKeys(rand.Reader, 3, 2) + assert.NilError(tb, err) + publicKeyShares := []*shcrypto.EonPublicKeyShare{} - identityPreimage := identitypreimage.BigToIdentityPreimage(common.Big0) - for i := uint64(0); i < tkg.NumKeypers; i++ { - share := tkg.EonPublicKeyShare(identityPreimage, i) + for i := 0; i < int(eonKeys.NumKeypers); i++ { + share := eonKeys.EonPublicKeyShare(i) publicKeyShares = append(publicKeyShares, share) } dkgResult := puredkg.Result{ - Eon: config.GetEon(), - NumKeypers: tkg.NumKeypers, - Threshold: tkg.Threshold, + Eon: 1, + NumKeypers: eonKeys.NumKeypers, + Threshold: eonKeys.Threshold, Keyper: keyperIndex, - SecretKeyShare: tkg.EonSecretKeyShare(identityPreimage, keyperIndex), - PublicKey: tkg.EonPublicKey(identityPreimage), + SecretKeyShare: eonKeys.EonSecretKeyShare(int(keyperIndex)), + PublicKey: eonKeys.EonPublicKey(), PublicKeyShares: publicKeyShares, } dkgResultEncoded, err := shdb.EncodePureDKGResult(&dkgResult) - assert.NilError(t, err) + assert.NilError(tb, err) err = keyperDB.InsertBatchConfig(ctx, database.InsertBatchConfigParams{ KeyperConfigIndex: 1, Height: 0, Keypers: keypers, - Threshold: int32(tkg.Threshold), + Threshold: int32(eonKeys.Threshold), }) - assert.NilError(t, err) + assert.NilError(tb, err) err = keyperDB.InsertEon(ctx, database.InsertEonParams{ Eon: int64(config.GetEon()), Height: 0, ActivationBlockNumber: 0, KeyperConfigIndex: 1, }) - assert.NilError(t, err) + assert.NilError(tb, err) err = keyperDB.InsertDKGResult(ctx, database.InsertDKGResultParams{ Eon: int64(config.GetEon()), Success: true, Error: sql.NullString{}, PureResult: dkgResultEncoded, }) - assert.NilError(t, err) + assert.NilError(tb, err) - return tkg + return eonKeys } diff --git a/rolling-shutter/p2p/p2ptest/p2ptest.go b/rolling-shutter/p2p/p2ptest/p2ptest.go index d47fe9ce3..1e01be0d9 100644 --- a/rolling-shutter/p2p/p2ptest/p2ptest.go +++ b/rolling-shutter/p2p/p2ptest/p2ptest.go @@ -16,13 +16,13 @@ import ( // MustValidateMessageResult calls the handlers ValidateMessage method and ensures it returns the // expected result. func MustValidateMessageResult( - t *testing.T, + tb testing.TB, expectedResult pubsub.ValidationResult, handler p2p.MessageHandler, ctx context.Context, //nolint:revive msg p2pmsg.Message, ) { - t.Helper() + tb.Helper() validationResult, err := handler.ValidateMessage(ctx, msg) accepted := validationResult == pubsub.ValidationAccept log.Debug(). @@ -30,21 +30,21 @@ func MustValidateMessageResult( Int("result", int(validationResult)). Int("expected", int(expectedResult)).Err(err).Msg("ValidateMessage") if accepted { - assert.NilError(t, err, "validation returned error") + assert.NilError(tb, err, "validation returned error") } - assert.Equal(t, expectedResult, validationResult, "validation did not validate with expected result ") + assert.Equal(tb, expectedResult, validationResult, "validation did not validate with expected result ") } // MustHandleMessage makes sure the handler validates and handles the given message without errors. func MustHandleMessage( - t *testing.T, + tb testing.TB, handler p2p.MessageHandler, ctx context.Context, //nolint:revive msg p2pmsg.Message, ) []p2pmsg.Message { - t.Helper() - MustValidateMessageResult(t, pubsub.ValidationAccept, handler, ctx, msg) + tb.Helper() + MustValidateMessageResult(tb, pubsub.ValidationAccept, handler, ctx, msg) msgs, err := handler.HandleMessage(ctx, msg) - assert.NilError(t, err) + assert.NilError(tb, err) return msgs } diff --git a/rolling-shutter/p2pmsg/messages_test.go b/rolling-shutter/p2pmsg/messages_test.go index def790818..5029fdaba 100644 --- a/rolling-shutter/p2pmsg/messages_test.go +++ b/rolling-shutter/p2pmsg/messages_test.go @@ -1,6 +1,7 @@ package p2pmsg import ( + "crypto/rand" "testing" "github.com/ethereum/go-ethereum/common" @@ -39,31 +40,34 @@ type testConfig struct { identityPreimage identitypreimage.IdentityPreimage blockNumber uint64 instanceID uint64 - tkg *testkeygen.TestKeyGenerator + keys *testkeygen.EonKeys } func defaultTestConfig(t *testing.T) testConfig { t.Helper() identityPreimage := identitypreimage.BigToIdentityPreimage(common.Big2) + keys, err := testkeygen.NewEonKeys(rand.Reader, 1, 1) + assert.NilError(t, err) return testConfig{ identityPreimage: identityPreimage, blockNumber: uint64(0), instanceID: uint64(42), - tkg: testkeygen.NewTestKeyGenerator(t, 1, 1, false), + keys: keys, } } func TestDecryptionKeys(t *testing.T) { cfg := defaultTestConfig(t) - validSecretKey := cfg.tkg.EpochSecretKey(cfg.identityPreimage).Marshal() + validSecretKey, err := cfg.keys.EpochSecretKey(cfg.identityPreimage) + assert.NilError(t, err) orig := &DecryptionKeys{ InstanceID: cfg.instanceID, Keys: []*Key{ { Identity: cfg.identityPreimage.Bytes(), - Key: validSecretKey, + Key: validSecretKey.Marshal(), }, }, } @@ -93,7 +97,7 @@ func TestDecryptionTrigger(t *testing.T) { func TestDecryptionKeyShare(t *testing.T) { cfg := defaultTestConfig(t) keyperIndex := uint64(0) - keyshare := cfg.tkg.EpochSecretKeyShare(cfg.identityPreimage, keyperIndex).Marshal() + keyshare := cfg.keys.EpochSecretKeyShare(cfg.identityPreimage, int(keyperIndex)).Marshal() orig := &DecryptionKeyShares{ InstanceID: cfg.instanceID, @@ -110,7 +114,7 @@ func TestDecryptionKeyShare(t *testing.T) { func TestEonPublicKey(t *testing.T) { cfg := defaultTestConfig(t) - eonPublicKey := cfg.tkg.EonPublicKey(cfg.identityPreimage).Marshal() + eonPublicKey := cfg.keys.EonPublicKey().Marshal() activationBlock := uint64(2) privKey, err := ethcrypto.GenerateKey() @@ -138,13 +142,14 @@ func TestTraceContext(t *testing.T) { TraceState: "tracestate", } cfg := defaultTestConfig(t) - validSecretKey := cfg.tkg.EpochSecretKey(cfg.identityPreimage).Marshal() + validSecretKey, err := cfg.keys.EpochSecretKey(cfg.identityPreimage) + assert.NilError(t, err) msg := &DecryptionKeys{ InstanceID: cfg.instanceID, Keys: []*Key{ { Identity: cfg.identityPreimage.Bytes(), - Key: validSecretKey, + Key: validSecretKey.Marshal(), }, }, } diff --git a/rolling-shutter/sandbox/keygen/keygen.go b/rolling-shutter/sandbox/keygen/keygen.go index 94ef6ceb6..e34b341aa 100644 --- a/rolling-shutter/sandbox/keygen/keygen.go +++ b/rolling-shutter/sandbox/keygen/keygen.go @@ -3,8 +3,8 @@ package main import ( "bytes" + "crypto/rand" "fmt" - "testing" "github.com/shutter-network/shutter/shlib/shcrypto" @@ -13,13 +13,19 @@ import ( ) func main() { - keygen := testkeygen.NewTestKeyGenerator(&testing.T{}, 3, 2, false) + keys, err := testkeygen.NewEonKeys(rand.Reader, 3, 2) + if err != nil { + panic(err) + } var prevEonPublicKey *shcrypto.EonPublicKey for i := uint64(0); i < 200; i++ { identityPreimage := identitypreimage.Uint64ToIdentityPreimage(i) - eonPublicKey := keygen.EonPublicKey(identityPreimage) - decryptionKey := keygen.EpochSecretKey(identityPreimage) + eonPublicKey := keys.EonPublicKey() + decryptionKey, err := keys.EpochSecretKey(identityPreimage) + if err != nil { + panic(err) + } if prevEonPublicKey == nil || !bytes.Equal(eonPublicKey.Marshal(), prevEonPublicKey.Marshal()) { if prevEonPublicKey != nil {