From afdcff05823a5f40fd4a58e6cbd604a6f1ba8b77 Mon Sep 17 00:00:00 2001 From: Andres Uribe Gonzalez Date: Wed, 12 Jul 2023 18:01:50 -0400 Subject: [PATCH] Supporting SQL based storage --- go.mod | 3 + go.sum | 9 +- pkg/server/router/credential_test.go | 32 +- pkg/server/router/did_test.go | 6 +- pkg/server/router/keystore_test.go | 2 +- pkg/server/router/manifest_test.go | 6 +- pkg/server/router/schema_test.go | 6 +- pkg/server/router/webhook_test.go | 2 +- pkg/server/server_did_test.go | 24 +- pkg/server/server_issuance_test.go | 12 +- pkg/server/server_keystore_test.go | 4 +- pkg/server/server_manifest_test.go | 18 +- pkg/server/server_operation_test.go | 20 +- pkg/server/server_presentation_test.go | 46 +-- pkg/server/server_schema_test.go | 8 +- pkg/server/server_webhook_test.go | 20 +- pkg/service/did/ion_test.go | 12 +- pkg/service/did/storage_test.go | 14 +- pkg/storage/db_test.go | 51 ++- pkg/storage/sql.go | 470 +++++++++++++++++++++++++ pkg/storage/storage.go | 5 +- 21 files changed, 648 insertions(+), 122 deletions(-) create mode 100644 pkg/storage/sql.go diff --git a/go.mod b/go.mod index c65e05195..9f2c9e66d 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/benbjohnson/clock v1.3.5 github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 github.com/cenkalti/backoff/v4 v4.2.1 + github.com/fergusstrange/embedded-postgres v1.23.0 github.com/gin-contrib/cors v1.4.0 github.com/gin-gonic/gin v1.9.1 github.com/go-playground/locales v0.14.1 @@ -117,6 +118,7 @@ require ( github.com/lestrrat-go/httprc v1.0.4 // indirect github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/option v1.0.1 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.19 // indirect @@ -155,6 +157,7 @@ require ( github.com/swaggo/swag v1.16.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect + github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect github.com/yuin/gopher-lua v1.1.0 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/otel/metric v1.16.0 // indirect diff --git a/go.sum b/go.sum index 35821e820..cc9d8e711 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,6 @@ github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbi github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= -github.com/TBD54566975/ssi-sdk v0.0.4-alpha.0.20230707203029-1bbf9c13be59 h1:O7GtPvUMsbJYFGx+In7ai1Sxm7gQaF7fenwlovpYbvs= -github.com/TBD54566975/ssi-sdk v0.0.4-alpha.0.20230707203029-1bbf9c13be59/go.mod h1:yPkVO9MCC/kRu+lut3jllhnCV0gEqSubaaSVT7xLSOs= github.com/TBD54566975/ssi-sdk v0.0.4-alpha.0.20230711190054-bce640c9bf25 h1:wW+49kQxN/BYcMkbDjQA9mkrUC9cYUV6HOFLP+JIx+E= github.com/TBD54566975/ssi-sdk v0.0.4-alpha.0.20230711190054-bce640c9bf25/go.mod h1:lup1EqGAT730/c7dRF9Q8OvgsjWJewq4K2dFAyBV1vk= github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= @@ -138,6 +136,8 @@ github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4 github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fergusstrange/embedded-postgres v1.23.0 h1:ZYRD89nammxQDWDi6taJE2CYjDuAoVc1TpEqRIYQryc= +github.com/fergusstrange/embedded-postgres v1.23.0/go.mod h1:wL562t1V+iuFwq0UcgMi2e9rp8CROY9wxWZEfP8Y874= github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= @@ -351,6 +351,8 @@ github.com/lestrrat-go/jwx/v2 v2.0.11/go.mod h1:ZtPtMFlrfDrH2Y0iwfa3dRFn8VzwBrB+ github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= @@ -488,6 +490,8 @@ github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6 github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= +github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -524,6 +528,7 @@ go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs= go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= +go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= diff --git a/pkg/server/router/credential_test.go b/pkg/server/router/credential_test.go index b8bdef16d..4b3e5dfd2 100644 --- a/pkg/server/router/credential_test.go +++ b/pkg/server/router/credential_test.go @@ -44,7 +44,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Credential Service Test", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) assert.NotEmpty(tt, s) serviceConfig := config.CredentialServiceConfig{BaseServiceConfig: &config.BaseServiceConfig{Name: "credential"}} @@ -201,7 +201,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Credential Service Test Revoked Key", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) assert.NotEmpty(tt, s) // Initialize services @@ -264,7 +264,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Credential Status List Test", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) assert.NotEmpty(tt, s) serviceConfig := config.CredentialServiceConfig{BaseServiceConfig: &config.BaseServiceConfig{Name: "credential", ServiceEndpoint: "v1/credentials"}} @@ -373,7 +373,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Credential Status List Test No Schemas", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) assert.NotEmpty(tt, s) serviceConfig := config.CredentialServiceConfig{BaseServiceConfig: &config.BaseServiceConfig{Name: "credential", ServiceEndpoint: "/v1/credentials"}} @@ -477,7 +477,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Credential Status List Test Update Revoked Status", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) assert.NotEmpty(tt, s) serviceConfig := config.CredentialServiceConfig{BaseServiceConfig: &config.BaseServiceConfig{Name: "credential", ServiceEndpoint: "http://localhost:1234/v1/credentials"}} @@ -597,7 +597,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Credential Status List Test Update Suspended Status", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) assert.NotEmpty(tt, s) serviceConfig := config.CredentialServiceConfig{BaseServiceConfig: &config.BaseServiceConfig{Name: "credential", ServiceEndpoint: "http://localhost:1234/v1/credentials"}} @@ -717,7 +717,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Create Multiple Suspendable Credential Different IssuerDID SchemaID StatusPurpose Triples", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) assert.NotEmpty(tt, s) serviceConfig := config.CredentialServiceConfig{BaseServiceConfig: &config.BaseServiceConfig{Name: "credential", ServiceEndpoint: "http://localhost:1234/v1/credentials"}} @@ -801,7 +801,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Create Suspendable Credential", func(tt *testing.T) { - issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(t)) + issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(tt)) subject := "did:test:345" createdCred, err := credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{ @@ -857,7 +857,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Update Suspendable Credential To Suspended", func(tt *testing.T) { - issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(t)) + issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(tt)) subject := "did:test:345" createdCred, err := credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{ @@ -935,7 +935,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Update Suspendable Credential To Suspended then Unsuspended", func(tt *testing.T) { - issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(t)) + issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(tt)) subject := "did:test:345" createdCred, err := credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{ @@ -1018,7 +1018,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Create Suspendable and Revocable Credential Should Be Error", func(tt *testing.T) { - issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(t)) + issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(tt)) subject := "did:test:345" createdCred, err := credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{ @@ -1040,7 +1040,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Update Suspendable and Revocable Credential Should Be Error", func(tt *testing.T) { - issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(t)) + issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(tt)) subject := "did:test:345" createdCred, err := credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{ @@ -1065,7 +1065,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Update Suspended On Revoked Credential Should Be Error", func(tt *testing.T) { - issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(t)) + issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(tt)) subject := "did:test:345" createdCred, err := credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{ @@ -1090,7 +1090,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Create Credential With Invalid Evidence", func(tt *testing.T) { - issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(t)) + issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(tt)) subject := "did:test:345" _, err := credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{ @@ -1109,7 +1109,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Create Credential With Invalid Evidence No Id", func(tt *testing.T) { - issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(t)) + issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(tt)) subject := "did:test:345" evidenceMap := map[string]any{ @@ -1137,7 +1137,7 @@ func TestCredentialRouter(t *testing.T) { }) t.Run("Create Credential With Evidence", func(tt *testing.T) { - issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(t)) + issuer, verificationMethodID, schemaID, credService := createCredServicePrereqs(tt, test.ServiceStorage(tt)) subject := "did:test:345" createdCred, err := credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{ diff --git a/pkg/server/router/did_test.go b/pkg/server/router/did_test.go index ce466b38f..8952ad9f5 100644 --- a/pkg/server/router/did_test.go +++ b/pkg/server/router/did_test.go @@ -36,7 +36,7 @@ func TestDIDRouter(t *testing.T) { // TODO: Fix pagesize issue on redis - https://github.com/TBD54566975/ssi-service/issues/538 if !strings.Contains(test.Name, "Redis") { t.Run("List DIDs supports paging", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) assert.NotEmpty(tt, db) keyStoreService := testKeyStoreService(tt, db) methods := []string{didsdk.KeyMethod.String()} @@ -77,7 +77,7 @@ func TestDIDRouter(t *testing.T) { } t.Run("DID Service Test", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) assert.NotEmpty(tt, db) keyStoreService := testKeyStoreService(tt, db) @@ -164,7 +164,7 @@ func TestDIDRouter(t *testing.T) { }) t.Run("DID Web Service Test", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) assert.NotEmpty(tt, db) keyStoreService := testKeyStoreService(tt, db) diff --git a/pkg/server/router/keystore_test.go b/pkg/server/router/keystore_test.go index 51b53870f..6db8efbdb 100644 --- a/pkg/server/router/keystore_test.go +++ b/pkg/server/router/keystore_test.go @@ -33,7 +33,7 @@ func TestKeyStoreRouter(t *testing.T) { for _, test := range testutil.TestDatabases { t.Run(test.Name, func(t *testing.T) { t.Run("Key Store Service Test", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) assert.NotEmpty(tt, db) serviceConfig := config.KeyStoreServiceConfig{ diff --git a/pkg/server/router/manifest_test.go b/pkg/server/router/manifest_test.go index 2b92c7ed8..db2a505cb 100644 --- a/pkg/server/router/manifest_test.go +++ b/pkg/server/router/manifest_test.go @@ -46,7 +46,7 @@ func TestManifestRouter(t *testing.T) { for _, test := range testutil.TestDatabases { t.Run(test.Name, func(t *testing.T) { t.Run("Manifest Service Test", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) assert.NotEmpty(tt, db) keyStoreService := testKeyStoreService(tt, db) @@ -93,10 +93,10 @@ func TestManifestRouter(t *testing.T) { createManifestRequest.PresentationDefinitionRef = &model.PresentationDefinitionRef{ ID: &resp.PresentationDefinition.ID, } - manifest, err := manifestService.CreateManifest(context.Background(), createManifestRequest) + createdManifest, err := manifestService.CreateManifest(context.Background(), createManifestRequest) assert.NoError(ttt, err) - assert.Equal(ttt, resp.PresentationDefinition, *manifest.Manifest.PresentationDefinition) + assert.Equal(ttt, resp.PresentationDefinition, *createdManifest.Manifest.PresentationDefinition) }) tt.Run("multiple behaviors", func(ttt *testing.T) { diff --git a/pkg/server/router/schema_test.go b/pkg/server/router/schema_test.go index 77e91f92b..de957dbd9 100644 --- a/pkg/server/router/schema_test.go +++ b/pkg/server/router/schema_test.go @@ -37,7 +37,7 @@ func TestSchemaRouter(t *testing.T) { t.Run(test.Name, func(t *testing.T) { t.Run("Schema Service Test", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) assert.NotEmpty(tt, db) serviceConfig := config.SchemaServiceConfig{BaseServiceConfig: &config.BaseServiceConfig{Name: "schema"}} @@ -117,7 +117,7 @@ func TestSchemaSigning(t *testing.T) { for _, test := range testutil.TestDatabases { t.Run(test.Name, func(t *testing.T) { t.Run("Unsigned Schema Test", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) assert.NotEmpty(tt, db) serviceConfig := config.SchemaServiceConfig{BaseServiceConfig: &config.BaseServiceConfig{Name: "schema"}} @@ -142,7 +142,7 @@ func TestSchemaSigning(t *testing.T) { }) t.Run("Signing schema with revoked key test", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) assert.NotEmpty(tt, db) serviceConfig := config.SchemaServiceConfig{BaseServiceConfig: &config.BaseServiceConfig{Name: "schema"}} diff --git a/pkg/server/router/webhook_test.go b/pkg/server/router/webhook_test.go index 2dbce6ab3..15ea44141 100644 --- a/pkg/server/router/webhook_test.go +++ b/pkg/server/router/webhook_test.go @@ -30,7 +30,7 @@ func TestWebhookRouter(t *testing.T) { for _, test := range testutil.TestDatabases { t.Run(test.Name, func(t *testing.T) { t.Run("Webhook Service Test", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) serviceConfig := config.WebhookServiceConfig{WebhookTimeout: "10s"} diff --git a/pkg/server/server_did_test.go b/pkg/server/server_did_test.go index 01f4c545f..b2815aaa6 100644 --- a/pkg/server/server_did_test.go +++ b/pkg/server/server_did_test.go @@ -31,7 +31,7 @@ func TestDIDAPI(t *testing.T) { t.Run(test.Name, func(t *testing.T) { t.Run("Test Get DID Methods", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) _, keyStoreService, _ := testKeyStore(tt, db) @@ -56,7 +56,7 @@ func TestDIDAPI(t *testing.T) { }) t.Run("Test Create DID By Method: Key", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) _, keyStoreService, _ := testKeyStore(tt, db) @@ -101,7 +101,7 @@ func TestDIDAPI(t *testing.T) { }) t.Run("Test Create DID By Method: Web", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) _, keyStoreService, _ := testKeyStore(tt, db) @@ -174,7 +174,7 @@ func TestDIDAPI(t *testing.T) { }) t.Run("Test Create DID By Method: ION", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) _, keyStoreService, _ := testKeyStore(tt, db) @@ -252,7 +252,7 @@ func TestDIDAPI(t *testing.T) { }) t.Run("Test Create Duplicate DID:Webs", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) _, keyStoreService, _ := testKeyStore(tt, db) @@ -312,7 +312,7 @@ func TestDIDAPI(t *testing.T) { }) t.Run("Test Get DID By Method", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) _, keyStore, _ := testKeyStore(tt, db) @@ -384,7 +384,7 @@ func TestDIDAPI(t *testing.T) { }) t.Run("Test Soft Delete DID By Method", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) _, keyStore, _ := testKeyStore(tt, db) @@ -509,7 +509,7 @@ func TestDIDAPI(t *testing.T) { }) t.Run("List DIDs made up token fails", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) _, keyStore, _ := testKeyStore(tt, db) didService, _ := testDIDRouter(tt, db, keyStore, []string{"key", "web"}, nil) @@ -528,7 +528,7 @@ func TestDIDAPI(t *testing.T) { t.Run("List DIDs pagination", func(tt *testing.T) { if !strings.Contains(test.Name, "Redis") { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) _, keyStore, _ := testKeyStore(tt, db) didRouter, _ := testDIDRouter(tt, db, keyStore, []string{"key", "web"}, nil) @@ -569,7 +569,7 @@ func TestDIDAPI(t *testing.T) { t.Run("List DIDs pagination change query between calls returns error", func(tt *testing.T) { if !strings.Contains(test.Name, "Redis") { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) _, keyStore, _ := testKeyStore(tt, db) didRouter, _ := testDIDRouter(tt, db, keyStore, []string{"key", "web"}, nil) @@ -605,7 +605,7 @@ func TestDIDAPI(t *testing.T) { }) t.Run("Test Get DIDs By Method", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) _, keyStore, _ := testKeyStore(tt, db) didService, _ := testDIDRouter(tt, db, keyStore, []string{"key", "web"}, nil) @@ -690,7 +690,7 @@ func TestDIDAPI(t *testing.T) { }) t.Run("Test Resolve DIDs", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) _, keyStore, _ := testKeyStore(tt, db) diff --git a/pkg/server/server_issuance_test.go b/pkg/server/server_issuance_test.go index 8726bcecc..27d1879af 100644 --- a/pkg/server/server_issuance_test.go +++ b/pkg/server/server_issuance_test.go @@ -34,7 +34,7 @@ func TestIssuanceRouter(t *testing.T) { for _, test := range testutil.TestDatabases { t.Run(test.Name, func(t *testing.T) { t.Run("CreateIssuanceTemplate", func(tt *testing.T) { - issuerResp, createdSchema, manifest, r := setupAllThings(tt, test.ServiceStorage(t)) + issuerResp, createdSchema, manifest, r := setupAllThings(tt, test.ServiceStorage(tt)) for _, tc := range []struct { name string request router.CreateIssuanceTemplateRequest @@ -103,7 +103,7 @@ func TestIssuanceRouter(t *testing.T) { }) t.Run("CreateIssuanceTemplate returns error", func(tt *testing.T) { - issuerResp, createdSchema, manifest, r := setupAllThings(tt, test.ServiceStorage(t)) + issuerResp, createdSchema, manifest, r := setupAllThings(tt, test.ServiceStorage(tt)) for _, tc := range []struct { name string @@ -244,7 +244,7 @@ func TestIssuanceRouter(t *testing.T) { }) t.Run("Create, Get, Delete work as expected", func(tt *testing.T) { - issuerResp, createdSchema, manifest, r := setupAllThings(tt, test.ServiceStorage(t)) + issuerResp, createdSchema, manifest, r := setupAllThings(tt, test.ServiceStorage(tt)) inputTemplate := issuance.Template{ CredentialManifest: manifest.Manifest.ID, @@ -318,7 +318,7 @@ func TestIssuanceRouter(t *testing.T) { }) t.Run("GetIssuanceTemplate returns error for unknown ID", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) r := testIssuanceRouter(tt, s) value := newRequestValue(tt, nil) @@ -330,7 +330,7 @@ func TestIssuanceRouter(t *testing.T) { }) t.Run("ListIssuanceTemplates returns empty when there aren't templates", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) r := testIssuanceRouter(tt, s) value := newRequestValue(tt, nil) @@ -346,7 +346,7 @@ func TestIssuanceRouter(t *testing.T) { }) t.Run("ListIssuanceTemplates returns all created templates", func(tt *testing.T) { - issuerResp, createdSchema, manifest, r := setupAllThings(tt, test.ServiceStorage(t)) + issuerResp, createdSchema, manifest, r := setupAllThings(tt, test.ServiceStorage(tt)) createSimpleTemplate(tt, manifest, issuerResp, createdSchema, now, r) createSimpleTemplate(tt, manifest, issuerResp, createdSchema, now, r) diff --git a/pkg/server/server_keystore_test.go b/pkg/server/server_keystore_test.go index 20ab159a8..252d6bc53 100644 --- a/pkg/server/server_keystore_test.go +++ b/pkg/server/server_keystore_test.go @@ -22,7 +22,7 @@ func TestKeyStoreAPI(t *testing.T) { for _, test := range testutil.TestDatabases { t.Run(test.Name, func(t *testing.T) { t.Run("Test Store Key", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) keyStoreRouter, _, _ := testKeyStore(tt, db) @@ -66,7 +66,7 @@ func TestKeyStoreAPI(t *testing.T) { }) t.Run("Test Get Key Details", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) keyStoreService, _, _ := testKeyStore(tt, db) diff --git a/pkg/server/server_manifest_test.go b/pkg/server/server_manifest_test.go index f2c5e2aef..3e50bc274 100644 --- a/pkg/server/server_manifest_test.go +++ b/pkg/server/server_manifest_test.go @@ -37,7 +37,7 @@ func TestManifestAPI(t *testing.T) { for _, test := range testutil.TestDatabases { t.Run(test.Name, func(t *testing.T) { t.Run("Test Create Manifest", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) keyStoreService, _ := testKeyStoreService(tt, db) @@ -110,7 +110,7 @@ func TestManifestAPI(t *testing.T) { }) t.Run("Test Get Manifest By ID", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) keyStoreService, _ := testKeyStoreService(tt, db) @@ -180,7 +180,7 @@ func TestManifestAPI(t *testing.T) { }) t.Run("Test Get Manifests", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) keyStoreService, _ := testKeyStoreService(tt, db) @@ -268,7 +268,7 @@ func TestManifestAPI(t *testing.T) { }) t.Run("Test Delete Manifest", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) keyStoreService, _ := testKeyStoreService(tt, db) @@ -337,7 +337,7 @@ func TestManifestAPI(t *testing.T) { }) t.Run("Submit Application With Issuance Template", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) keyStoreService, _ := testKeyStoreService(tt, db) @@ -490,7 +490,7 @@ func TestManifestAPI(t *testing.T) { }) t.Run("Test Submit Application with multiple outputs and overrides", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) keyStoreService, _ := testKeyStoreService(tt, db) @@ -655,7 +655,7 @@ func TestManifestAPI(t *testing.T) { }) t.Run("Test Denied Application", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) keyStoreService, _ := testKeyStoreService(tt, db) @@ -804,7 +804,7 @@ func TestManifestAPI(t *testing.T) { }) t.Run("Test Get Application By ID and Get Applications", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) keyStoreService, _ := testKeyStoreService(tt, db) @@ -989,7 +989,7 @@ func TestManifestAPI(t *testing.T) { }) t.Run("Test Delete Application", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) keyStoreService, _ := testKeyStoreService(tt, db) diff --git a/pkg/server/server_operation_test.go b/pkg/server/server_operation_test.go index 5bfc32157..b922caa82 100644 --- a/pkg/server/server_operation_test.go +++ b/pkg/server/server_operation_test.go @@ -25,7 +25,7 @@ func TestOperationsAPI(t *testing.T) { for _, test := range testutil.TestDatabases { t.Run(test.Name, func(t *testing.T) { t.Run("Marks operation as done after reviewing submission", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) pRouter, didService := setupPresentationRouter(tt, s) authorDID := createDID(tt, didService) opRouter := setupOperationsRouter(tt, s) @@ -57,7 +57,7 @@ func TestOperationsAPI(t *testing.T) { t.Run("GetOperation", func(tt *testing.T) { tt.Run("Returns operation after submission", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) opRouter := setupOperationsRouter(ttt, s) @@ -84,7 +84,7 @@ func TestOperationsAPI(t *testing.T) { }) tt.Run("Returns error when id doesn't exist", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) opRouter := setupOperationsRouter(ttt, s) req := httptest.NewRequest(http.MethodPut, "https://ssi-service.com/v1/operations/some_fake_id", nil) @@ -98,7 +98,7 @@ func TestOperationsAPI(t *testing.T) { t.Run("ListOperations", func(tt *testing.T) { tt.Run("Returns empty when no operations stored", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) opRouter := setupOperationsRouter(ttt, s) query := url.QueryEscape("presentations/submissions") @@ -115,7 +115,7 @@ func TestOperationsAPI(t *testing.T) { }) tt.Run("Returns one operation for every submission", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) opRouter := setupOperationsRouter(ttt, s) @@ -150,7 +150,7 @@ func TestOperationsAPI(t *testing.T) { }) tt.Run("Returns operation when filtering to include", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) opRouter := setupOperationsRouter(ttt, s) @@ -178,7 +178,7 @@ func TestOperationsAPI(t *testing.T) { if !strings.Contains(test.Name, "Redis") { tt.Run("Returns zero operations when filtering to exclude", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) opRouter := setupOperationsRouter(ttt, s) @@ -207,7 +207,7 @@ func TestOperationsAPI(t *testing.T) { if !strings.Contains(test.Name, "Redis") { tt.Run("Returns zero operations when wrong parent is specified", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) opRouter := setupOperationsRouter(ttt, s) @@ -234,7 +234,7 @@ func TestOperationsAPI(t *testing.T) { t.Run("CancelOperation", func(tt *testing.T) { tt.Run("Marks an operation as done", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) opRouter := setupOperationsRouter(ttt, s) @@ -259,7 +259,7 @@ func TestOperationsAPI(t *testing.T) { }) tt.Run("Returns error when operation is done already", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) opRouter := setupOperationsRouter(ttt, s) diff --git a/pkg/server/server_presentation_test.go b/pkg/server/server_presentation_test.go index 77ac15c63..459db3d01 100644 --- a/pkg/server/server_presentation_test.go +++ b/pkg/server/server_presentation_test.go @@ -58,7 +58,7 @@ func TestPresentationAPI(t *testing.T) { for _, test := range testutil.TestDatabases { t.Run(test.Name, func(t *testing.T) { t.Run("Create, Get, and Delete PresentationDefinition", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) pRouter, _ := setupPresentationRouter(tt, s) var createdID string @@ -121,7 +121,7 @@ func TestPresentationAPI(t *testing.T) { }) t.Run("List presentation requests returns empty", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) pRouter, _ := setupPresentationRouter(tt, s) req := httptest.NewRequest(http.MethodGet, "https://ssi-service.com/v1/presentations/requests", nil) @@ -136,7 +136,7 @@ func TestPresentationAPI(t *testing.T) { }) t.Run("Get presentation requests returns created request", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) pRouter, didService := setupPresentationRouter(tt, s) issuerDID := createDID(tt, didService) def := createPresentationDefinition(tt, pRouter) @@ -158,7 +158,7 @@ func TestPresentationAPI(t *testing.T) { }) t.Run("List presentation requests returns many requests", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) pRouter, didService := setupPresentationRouter(tt, s) issuerDID := createDID(tt, didService) def := createPresentationDefinition(tt, pRouter) @@ -182,7 +182,7 @@ func TestPresentationAPI(t *testing.T) { }) t.Run("List definitions returns empty", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) pRouter, _ := setupPresentationRouter(tt, s) req := httptest.NewRequest(http.MethodGet, "https://ssi-service.com/v1/presentations/definitions", nil) @@ -197,7 +197,7 @@ func TestPresentationAPI(t *testing.T) { }) t.Run("List definitions returns many definitions", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) pRouter, _ := setupPresentationRouter(tt, s) def1 := createPresentationDefinition(tt, pRouter) def2 := createPresentationDefinition(tt, pRouter) @@ -218,7 +218,7 @@ func TestPresentationAPI(t *testing.T) { }) t.Run("Create returns error without input descriptors", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) pRouter, _ := setupPresentationRouter(tt, s) request := router.CreatePresentationDefinitionRequest{} value := newRequestValue(tt, request) @@ -231,7 +231,7 @@ func TestPresentationAPI(t *testing.T) { }) t.Run("Get without an ID returns error", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) pRouter, _ := setupPresentationRouter(tt, s) req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("https://ssi-service.com/v1/presentations/definitions/%s", pd.ID), nil) w := httptest.NewRecorder() @@ -242,7 +242,7 @@ func TestPresentationAPI(t *testing.T) { }) t.Run("Delete without an ID returns error", func(tt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) pRouter, _ := setupPresentationRouter(tt, s) req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("https://ssi-service.com/v1/presentations/definitions/%s", pd.ID), nil) @@ -254,7 +254,7 @@ func TestPresentationAPI(t *testing.T) { t.Run("Submission endpoints", func(tt *testing.T) { tt.Run("Get non-existing ID returns error", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, _ := setupPresentationRouter(ttt, s) req := httptest.NewRequest(http.MethodGet, "https://ssi-service.com/v1/presentations/submissions/myrandomid", nil) @@ -265,7 +265,7 @@ func TestPresentationAPI(t *testing.T) { }) tt.Run("Get returns submission after creation", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) @@ -295,7 +295,7 @@ func TestPresentationAPI(t *testing.T) { }) tt.Run("Create well formed submission returns operation", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) @@ -326,7 +326,7 @@ func TestPresentationAPI(t *testing.T) { }) tt.Run("Review submission returns approved submission", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) @@ -360,7 +360,7 @@ func TestPresentationAPI(t *testing.T) { }) tt.Run("Review submission twice fails", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) @@ -387,7 +387,7 @@ func TestPresentationAPI(t *testing.T) { }) tt.Run("List submissions returns empty when there are none", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, _ := setupPresentationRouter(ttt, s) req := httptest.NewRequest(http.MethodGet, "https://ssi-service.com/v1/presentations/submissions", nil) @@ -403,7 +403,7 @@ func TestPresentationAPI(t *testing.T) { }) tt.Run("List submissions invalid page size fails", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, _ := setupPresentationRouter(ttt, s) w := httptest.NewRecorder() @@ -418,7 +418,7 @@ func TestPresentationAPI(t *testing.T) { }) tt.Run("List submissions made up token fails", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, _ := setupPresentationRouter(ttt, s) w := httptest.NewRecorder() @@ -438,7 +438,7 @@ func TestPresentationAPI(t *testing.T) { if strings.Contains(test.Name, "Redis") { ttt.Skip("skipping pagination test for Redis") } - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) @@ -497,7 +497,7 @@ func TestPresentationAPI(t *testing.T) { if strings.Contains(test.Name, "Redis") { ttt.Skip("skipping pagination test for Redis") } - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) @@ -549,7 +549,7 @@ func TestPresentationAPI(t *testing.T) { }) tt.Run("List submissions returns many submissions", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) @@ -634,7 +634,7 @@ func TestPresentationAPI(t *testing.T) { }) tt.Run("bad filter returns error", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, _ := setupPresentationRouter(ttt, s) query := url.QueryEscape("im a baaad filter that's trying to break a lot of stuff") @@ -647,7 +647,7 @@ func TestPresentationAPI(t *testing.T) { }) tt.Run("List submissions filters based on status", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) @@ -700,7 +700,7 @@ func TestPresentationAPI(t *testing.T) { }) tt.Run("List submissions filter returns empty when status does not match", func(ttt *testing.T) { - s := test.ServiceStorage(t) + s := test.ServiceStorage(ttt) pRouter, didService := setupPresentationRouter(ttt, s) authorDID := createDID(ttt, didService) diff --git a/pkg/server/server_schema_test.go b/pkg/server/server_schema_test.go index 2292b5c94..ef2f0e6ee 100644 --- a/pkg/server/server_schema_test.go +++ b/pkg/server/server_schema_test.go @@ -24,7 +24,7 @@ func TestSchemaAPI(t *testing.T) { for _, test := range testutil.TestDatabases { t.Run(test.Name, func(t *testing.T) { t.Run("Test Create JsonSchema2023 Schema", func(tt *testing.T) { - bolt := test.ServiceStorage(t) + bolt := test.ServiceStorage(tt) require.NotEmpty(tt, bolt) keyStoreService, _ := testKeyStoreService(tt, bolt) @@ -66,7 +66,7 @@ func TestSchemaAPI(t *testing.T) { }) t.Run("Test Create CredentialSchema2023 Schema", func(tt *testing.T) { - bolt := test.ServiceStorage(t) + bolt := test.ServiceStorage(tt) require.NotEmpty(tt, bolt) keyStoreService, _ := testKeyStoreService(tt, bolt) @@ -143,7 +143,7 @@ func TestSchemaAPI(t *testing.T) { }) t.Run("Test Get Schema and Get Schemas", func(tt *testing.T) { - bolt := test.ServiceStorage(t) + bolt := test.ServiceStorage(tt) require.NotEmpty(tt, bolt) keyStoreService, _ := testKeyStoreService(tt, bolt) @@ -235,7 +235,7 @@ func TestSchemaAPI(t *testing.T) { }) t.Run("Test Delete Schema", func(tt *testing.T) { - bolt := test.ServiceStorage(t) + bolt := test.ServiceStorage(tt) require.NotEmpty(tt, bolt) keyStoreService, _ := testKeyStoreService(tt, bolt) diff --git a/pkg/server/server_webhook_test.go b/pkg/server/server_webhook_test.go index 8c12609b4..ec6f15300 100644 --- a/pkg/server/server_webhook_test.go +++ b/pkg/server/server_webhook_test.go @@ -143,7 +143,7 @@ func TestWebhookAPI(t *testing.T) { for _, test := range testutil.TestDatabases { t.Run(test.Name, func(t *testing.T) { t.Run("CreateWebhook returns error when missing request", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) webhookRouter := testWebhookRouter(tt, db) @@ -162,7 +162,7 @@ func TestWebhookAPI(t *testing.T) { }) t.Run("CreateWebhook returns error when verb is not supported", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) webhookRouter := testWebhookRouter(tt, db) @@ -183,7 +183,7 @@ func TestWebhookAPI(t *testing.T) { }) t.Run("CreateWebhook returns error when url is not supported", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) webhookRouter := testWebhookRouter(tt, db) @@ -204,7 +204,7 @@ func TestWebhookAPI(t *testing.T) { }) t.Run("CreateWebhook returns error when url is is missing scheme", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) webhookRouter := testWebhookRouter(tt, db) @@ -225,7 +225,7 @@ func TestWebhookAPI(t *testing.T) { }) t.Run("CreateWebhook returns valid response", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) webhookRouter := testWebhookRouter(tt, db) @@ -246,7 +246,7 @@ func TestWebhookAPI(t *testing.T) { }) t.Run("Test Happy Path Delete Webhook", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) webhookRouter := testWebhookRouter(tt, db) @@ -309,7 +309,7 @@ func TestWebhookAPI(t *testing.T) { }) t.Run("GetWebhook Throws Error When Webhook None Exist", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) webhookService := testWebhookService(tt, db) @@ -320,7 +320,7 @@ func TestWebhookAPI(t *testing.T) { }) t.Run("GetWebhook Returns Webhook That Does Exist", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) webhookService := testWebhookService(tt, db) @@ -347,7 +347,7 @@ func TestWebhookAPI(t *testing.T) { }) t.Run("Test Get Webhooks", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) webhookService := testWebhookService(tt, db) @@ -383,7 +383,7 @@ func TestWebhookAPI(t *testing.T) { }) t.Run("Test Delete Webhook", func(tt *testing.T) { - db := test.ServiceStorage(t) + db := test.ServiceStorage(tt) require.NotEmpty(tt, db) webhookService := testWebhookService(tt, db) diff --git a/pkg/service/did/ion_test.go b/pkg/service/did/ion_test.go index 6fdd0387e..7612d9275 100644 --- a/pkg/service/did/ion_test.go +++ b/pkg/service/did/ion_test.go @@ -32,7 +32,7 @@ func TestIONHandler(t *testing.T) { assert.Empty(tt, handler) assert.Contains(tt, err.Error(), "baseURL cannot be empty") - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) keystoreService := testKeyStoreService(tt, s) didStorage, err := NewDIDStorage(s) assert.NoError(tt, err) @@ -60,7 +60,7 @@ func TestIONHandler(t *testing.T) { t.Run("Create DID", func(tt *testing.T) { // create a handler - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) keystoreService := testKeyStoreService(tt, s) didStorage, err := NewDIDStorage(s) assert.NoError(tt, err) @@ -132,7 +132,7 @@ func TestIONHandler(t *testing.T) { defer gock.Off() // create a handler - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) keystoreService := testKeyStoreService(tt, s) didStorage, err := NewDIDStorage(s) assert.NoError(tt, err) @@ -159,7 +159,7 @@ func TestIONHandler(t *testing.T) { t.Run("Get DID from storage", func(tt *testing.T) { // create a handler - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) keystoreService := testKeyStoreService(tt, s) didStorage, err := NewDIDStorage(s) assert.NoError(tt, err) @@ -198,7 +198,7 @@ func TestIONHandler(t *testing.T) { t.Run("Get DIDs from storage", func(tt *testing.T) { // create a handler - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) keystoreService := testKeyStoreService(tt, s) didStorage, err := NewDIDStorage(s) assert.NoError(tt, err) @@ -256,7 +256,7 @@ func TestIONHandler(t *testing.T) { t.Run("Get DID from resolver", func(tt *testing.T) { // create a handler - s := test.ServiceStorage(t) + s := test.ServiceStorage(tt) keystoreService := testKeyStoreService(tt, s) didStorage, err := NewDIDStorage(s) assert.NoError(tt, err) diff --git a/pkg/service/did/storage_test.go b/pkg/service/did/storage_test.go index b5abd57b0..4491c9f44 100644 --- a/pkg/service/did/storage_test.go +++ b/pkg/service/did/storage_test.go @@ -14,7 +14,7 @@ func TestStorage(t *testing.T) { t.Run(test.Name, func(t *testing.T) { t.Run("Create bad DID - no namespace", func(tt *testing.T) { - ds, err := NewDIDStorage(test.ServiceStorage(t)) + ds, err := NewDIDStorage(test.ServiceStorage(tt)) assert.NoError(tt, err) // create a did @@ -33,7 +33,7 @@ func TestStorage(t *testing.T) { }) t.Run("Get bad DID - namespace does not exist", func(tt *testing.T) { - ds, err := NewDIDStorage(test.ServiceStorage(t)) + ds, err := NewDIDStorage(test.ServiceStorage(tt)) assert.NoError(tt, err) // store @@ -44,7 +44,7 @@ func TestStorage(t *testing.T) { }) t.Run("Get bad DID - does not exist", func(tt *testing.T) { - ds, err := NewDIDStorage(test.ServiceStorage(t)) + ds, err := NewDIDStorage(test.ServiceStorage(tt)) assert.NoError(tt, err) // store @@ -55,7 +55,7 @@ func TestStorage(t *testing.T) { }) t.Run("Create and Get DID", func(tt *testing.T) { - ds, err := NewDIDStorage(test.ServiceStorage(t)) + ds, err := NewDIDStorage(test.ServiceStorage(tt)) assert.NoError(tt, err) // create a did @@ -84,7 +84,7 @@ func TestStorage(t *testing.T) { }) t.Run("Create and Get DID of a custom type", func(tt *testing.T) { - ds, err := NewDIDStorage(test.ServiceStorage(t)) + ds, err := NewDIDStorage(test.ServiceStorage(tt)) assert.NoError(tt, err) // create a did @@ -110,7 +110,7 @@ func TestStorage(t *testing.T) { }) t.Run("Create and Get Multiple DIDs", func(tt *testing.T) { - ds, err := NewDIDStorage(test.ServiceStorage(t)) + ds, err := NewDIDStorage(test.ServiceStorage(tt)) assert.NoError(tt, err) // create two dids @@ -153,7 +153,7 @@ func TestStorage(t *testing.T) { }) t.Run("Soft delete DID", func(tt *testing.T) { - ds, err := NewDIDStorage(test.ServiceStorage(t)) + ds, err := NewDIDStorage(test.ServiceStorage(tt)) assert.NoError(tt, err) // create two dids diff --git a/pkg/storage/db_test.go b/pkg/storage/db_test.go index 303fd972a..5c8980b53 100644 --- a/pkg/storage/db_test.go +++ b/pkg/storage/db_test.go @@ -2,22 +2,33 @@ package storage import ( "context" + "crypto/rand" + "encoding/binary" "fmt" "os" + "path/filepath" + "strconv" "testing" "github.com/alicebob/miniredis/v2" + embeddedpostgres "github.com/fergusstrange/embedded-postgres" "github.com/goccy/go-json" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func getDBImplementations(t *testing.T) []ServiceStorage { + dbImpls := make([]ServiceStorage, 0) + boltDB := setupBoltDB(t) + dbImpls = append(dbImpls, boltDB) + redisDB := setupRedisDB(t) + dbImpls = append(dbImpls, redisDB) + + postgresDB := setupPostgresDB(t) + dbImpls = append(dbImpls, postgresDB) - dbImpls := make([]ServiceStorage, 0) - dbImpls = append(dbImpls, boltDB, redisDB) return dbImpls } @@ -37,6 +48,41 @@ func setupBoltDB(t *testing.T) *BoltDB { return db.(*BoltDB) } +func setupPostgresDB(t *testing.T) *SQLDB { + homeDir, err := os.UserHomeDir() + require.NoError(t, err) + + scalar := make([]byte, 32) + _, err = rand.Read(scalar) + require.NoError(t, err) + + randomDir := strconv.Itoa(int(binary.BigEndian.Uint32(scalar))) + postgres := embeddedpostgres.NewDatabase(embeddedpostgres.DefaultConfig(). + BinariesPath(filepath.Join(homeDir, ".embedded-postgres-go", "tmpBin")). + DataPath(filepath.Join(os.TempDir(), ".embedded-postgres-go", "data", randomDir)). + RuntimePath(filepath.Join(os.TempDir(), ".embedded-postgres-go", "runtime", randomDir))) + err = postgres.Start() + require.NoError(t, err) + + t.Cleanup(func() { + _ = postgres.Stop() + }) + + options := []Option{ + { + ID: SQLConnectionString, + Option: "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable", + }, + { + ID: SQLDriverName, + Option: "postgres", + }, + } + s, err := NewStorage(DatabaseSQL, options...) + require.NoError(t, err) + return s.(*SQLDB) +} + func setupRedisDB(t *testing.T) *RedisDB { server := miniredis.RunT(t) options := []Option{ @@ -129,6 +175,7 @@ func TestDB(t *testing.T) { // delete a namespace that doesn't exist err = db.DeleteNamespace(context.Background(), "bad") + assert.Error(t, err) assert.Contains(t, err.Error(), "could not delete namespace") // delete namespace diff --git a/pkg/storage/sql.go b/pkg/storage/sql.go new file mode 100644 index 000000000..19f179c73 --- /dev/null +++ b/pkg/storage/sql.go @@ -0,0 +1,470 @@ +package storage + +import ( + "context" + "database/sql" + "encoding/base64" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +func init() { + if err := RegisterStorage(new(SQLDB)); err != nil { + panic(err) + } +} + +const ( + SQLConnectionString OptionKey = "sql-connection-string-option" + SQLDriverName OptionKey = "sql-driver-name-option" +) + +type SQLDB struct { + db *sql.DB + connectionString string +} + +func (s *SQLDB) Init(opts ...Option) error { + connString, sqlDriverName, err := processSQLOptions(opts...) + if err != nil { + return err + } + s.connectionString = connString + + db, err := sql.Open(sqlDriverName, connString) + if err != nil { + return err + } + + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS key_values ( + key varchar, + value varchar +);`) + if err != nil { + return err + } + + _, err = db.Exec(`CREATE INDEX idx_key_values ON key_values USING hash (key);`) + if err != nil { + return err + } + + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS namespaces ( + namespace varchar +);`) + if err != nil { + return err + } + + _, err = db.Exec(`CREATE INDEX idx_namespaces ON namespaces USING hash (namespace);`) + if err != nil { + return err + } + + s.db = db + return nil +} + +func processSQLOptions(opts ...Option) (connString string, sqlDriverName string, err error) { + if len(opts) != 2 { + return "", "", errors.New("sql options must contain connection string and driver name") + } + for _, opt := range opts { + switch opt.ID { + case SQLConnectionString: + maybeConnString, ok := opt.Option.(string) + if !ok { + err = errors.New("sql connection string must be a string") + return + } + if len(maybeConnString) == 0 { + err = errors.New("sql connection string must not be empty") + return + } + connString = maybeConnString + case SQLDriverName: + maybeDriverName, ok := opt.Option.(string) + if !ok { + err = errors.New("sql driver name must be a string") + return + } + if len(maybeDriverName) == 0 { + err = errors.New("sql driver name must not be empty") + return + } + sqlDriverName = maybeDriverName + } + } + if len(connString) == 0 || len(sqlDriverName) == 0 { + err = errors.New("sql connection string and driver name must not be empty") + return + } + return connString, sqlDriverName, nil +} + +func (s *SQLDB) Type() Type { + return DatabaseSQL +} + +func (s *SQLDB) URI() string { + return s.connectionString +} + +func (s *SQLDB) IsOpen() bool { + err := s.db.Ping() + if err != nil { + logrus.WithError(err).Error("pinging db") + return false + } + return true +} + +func (s *SQLDB) Close() error { + return s.db.Close() +} + +func (s *SQLDB) Write(ctx context.Context, namespace, key string, value []byte) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func(tx *sql.Tx) { + err := tx.Rollback() + if err != nil { + logrus.WithError(err).Error("unable to rollback") + } + }(tx) + + if err := write(ctx, tx, namespace, key, value); err != nil { + return err + } + + if err := tx.Commit(); err != nil { + return errors.Wrap(err, "committing transaction") + } + return nil +} + +type ExecContext interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + +func write(ctx context.Context, db ExecContext, namespace, key string, value []byte) error { + _, err := db.ExecContext(ctx, "INSERT INTO namespaces (namespace) VALUES ($1) EXCEPT SELECT namespace FROM namespaces WHERE namespace = $2", namespace, namespace) + if err != nil { + return err + } + _, err = db.ExecContext(ctx, "INSERT INTO key_values (key, value) VALUES ($1, $2)", Join(namespace, key), base64.RawStdEncoding.EncodeToString(value)) + return err +} + +func (s *SQLDB) WriteMany(ctx context.Context, namespaces, keys []string, values [][]byte) error { + stmt, err := s.db.Prepare("INSERT INTO key_values (key, value) VALUES ($1, $2)") + if err != nil { + return err + } + defer func(stmt *sql.Stmt) { + _ = stmt.Close() + }(stmt) + + for i, k := range keys { + _, err = stmt.ExecContext(ctx, Join(namespaces[i], k), base64.RawStdEncoding.EncodeToString(values[i])) + if err != nil { + return err + } + } + return err +} + +func (s *SQLDB) Read(ctx context.Context, namespace, key string) ([]byte, error) { + return read(ctx, s.db, namespace, key) +} + +type QueryRow interface { + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +func read(ctx context.Context, db QueryRow, namespace, key string) ([]byte, error) { + r := db.QueryRowContext(ctx, "SELECT value FROM key_values WHERE key = $1", Join(namespace, key)) + var value string + err := r.Scan(&value) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + decoded, err := base64.RawStdEncoding.DecodeString(value) + if err != nil { + return nil, err + } + return decoded, nil +} + +func (s *SQLDB) Exists(ctx context.Context, namespace, key string) (bool, error) { + query := ` + SELECT EXISTS ( + SELECT 1 + FROM key_values + WHERE key = $1 + LIMIT 1 + ) + ` + + // Execute the query and retrieve the result + var exists bool + err := s.db.QueryRowContext(ctx, query, Join(namespace, key)).Scan(&exists) + if err != nil { + return false, err + } + + return exists, nil +} + +func (s *SQLDB) ReadAll(ctx context.Context, namespace string) (map[string][]byte, error) { + rows, err := s.db.QueryContext(ctx, "SELECT key, value FROM key_values WHERE key LIKE $1", Join(namespace, "%")) + if err != nil { + return nil, err + } + defer func(rows *sql.Rows) { + err := rows.Close() + if err != nil { + logrus.WithError(err).Error("closing rows") + } + }(rows) + + allValues, _, err := readRowsAsMap(rows, namespace) + return allValues, err +} + +func readRowsAsMap(rows *sql.Rows, namespace string) (map[string][]byte, string, error) { + allValues := make(map[string][]byte) + var mapKey string + for rows.Next() { + var key string + var value string + if err := rows.Scan(&key, &value); err != nil { + return nil, "", err + } + mapKey = key[len(namespace)+1:] + decoded, err := base64.RawStdEncoding.DecodeString(value) + if err != nil { + return nil, "", err + } + allValues[mapKey] = decoded + } + if err := rows.Err(); err != nil { + return nil, "", err + } + return allValues, mapKey, nil +} + +func (s *SQLDB) ReadPage(ctx context.Context, namespace string, pageToken string, pageSize int) (results map[string][]byte, nextPageToken string, err error) { + var rows *sql.Rows + if pageSize == -1 { + rows, err = s.db.QueryContext(ctx, "SELECT * FROM key_values WHERE key LIKE $1 AND key >= $2 ORDER BY key", Join(namespace, "%"), pageToken) + } else { + rows, err = s.db.QueryContext(ctx, "SELECT * FROM key_values WHERE key LIKE $1 AND key >= $2 ORDER BY key LIMIT $3", Join(namespace, "%"), pageToken, pageSize+1) + } + if err != nil { + + if errors.Is(err, sql.ErrNoRows) { + return nil, "", nil + } + return nil, "", err + } + pageValues, lastMapKey, err := readRowsAsMap(rows, namespace) + if err != nil { + return nil, "", err + } + if pageSize == -1 { + nextPageToken = "" + } else { + if len(pageValues) <= pageSize { + nextPageToken = "" + } else { + nextPageToken = Join(namespace, lastMapKey) + delete(pageValues, lastMapKey) + } + } + return pageValues, nextPageToken, nil +} + +func (s *SQLDB) ReadPrefix(ctx context.Context, namespace, prefix string) (map[string][]byte, error) { + rows, err := s.db.QueryContext(ctx, "SELECT key, value FROM key_values WHERE key LIKE $1", Join(namespace, prefix)+"%") + if err != nil { + return nil, err + } + defer func(rows *sql.Rows) { + err := rows.Close() + if err != nil { + logrus.WithError(err).Error("closing rows") + } + }(rows) + + allValues, _, err := readRowsAsMap(rows, namespace) + return allValues, err +} + +func (s *SQLDB) ReadAllKeys(ctx context.Context, namespace string) ([]string, error) { + rows, err := s.db.QueryContext(ctx, "SELECT key FROM key_values WHERE key LIKE $1", Join(namespace, "%")) + if err != nil { + return nil, err + } + defer func(rows *sql.Rows) { + err := rows.Close() + if err != nil { + logrus.WithError(err).Error("closing rows") + } + }(rows) + + var keys []string + for rows.Next() { + var key string + if err := rows.Scan(&key); err != nil { + return nil, err + } + keys = append(keys, key[len(namespace)+1:]) + } + if err := rows.Err(); err != nil { + return nil, err + } + return keys, err +} + +func (s *SQLDB) Delete(ctx context.Context, namespace, key string) error { + row := s.db.QueryRowContext(ctx, "SELECT * FROM namespaces WHERE namespace = $1", namespace) + var gotNamespace string + if err := row.Scan(&gotNamespace); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return errors.Errorf("namespace<%s> does not exist", namespace) + } + return err + } + _, err := s.db.ExecContext(ctx, "DELETE FROM key_values WHERE key = $1", Join(namespace, key)) + if err != nil { + return err + } + return nil +} + +func (s *SQLDB) DeleteNamespace(ctx context.Context, namespace string) error { + row := s.db.QueryRowContext(ctx, "DELETE FROM namespaces WHERE namespace = $1 RETURNING *", namespace) + var namespaceRemoved string + if err := row.Scan(&namespaceRemoved); err != nil { + return errors.Wrap(err, "could not delete namespace") + } + + _, err := s.db.ExecContext(ctx, "DELETE FROM key_values WHERE key LIKE $1", Join(namespace, "%")) + if err != nil { + return err + } + return nil +} + +func (s *SQLDB) Update(ctx context.Context, namespace string, key string, values map[string]any) ([]byte, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer func(tx *sql.Tx) { + err := tx.Rollback() + if err != nil { + logrus.WithError(err).Error("unable to rollback") + } + }(tx) + updater := NewUpdater(values) + updatedValue, err := updateValue(ctx, namespace, key, updater, tx) + if err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + return updatedValue, nil +} + +func (s *SQLDB) UpdateValueAndOperation(ctx context.Context, namespace, key string, updater Updater, opNamespace, opKey string, opUpdater ResponseSettingUpdater) (first, op []byte, err error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, nil, err + } + defer func(tx *sql.Tx) { + err := tx.Rollback() + if err != nil { + logrus.WithError(err).Error("unable to rollback") + } + }(tx) + + updatedValue, err := updateValue(ctx, namespace, key, updater, tx) + if err != nil { + return nil, nil, err + } + + opUpdater.SetUpdatedResponse(updatedValue) + + updatedOpValue, err := updateValue(ctx, opNamespace, opKey, opUpdater, tx) + + if err := tx.Commit(); err != nil { + return nil, nil, err + } + return updatedValue, updatedOpValue, err +} + +func updateValue(ctx context.Context, namespace string, key string, updater Updater, tx *sql.Tx) ([]byte, error) { + currentValue, err := read(ctx, tx, namespace, key) + if err != nil { + return nil, err + } + if err := updater.Validate(currentValue); err != nil { + return nil, errors.Wrap(err, "validating update") + } + updatedValue, err := updater.Update(currentValue) + if err != nil { + return nil, err + } + encodedUpdatedValue := base64.RawStdEncoding.EncodeToString(updatedValue) + _, err = tx.ExecContext(ctx, "UPDATE key_values SET value = $1 WHERE key = $2", encodedUpdatedValue, Join(namespace, key)) + if err != nil { + return nil, err + } + return updatedValue, nil +} + +type sqlTx struct { + tx *sql.Tx +} + +func (s *sqlTx) Write(ctx context.Context, namespace, key string, value []byte) error { + return write(ctx, s.tx, namespace, key, value) +} + +func (s *SQLDB) Execute(ctx context.Context, businessLogicFunc BusinessLogicFunc, _ []WatchKey) (any, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer func(tx *sql.Tx) { + err := tx.Rollback() + if err != nil { + logrus.Errorf("problem rolling back %s", err) + } + }(tx) + + bTx := sqlTx{tx: tx} + + result, err := businessLogicFunc(ctx, &bTx) + if err != nil { + return nil, errors.Wrap(err, "executing business logic func") + } + + if err := tx.Commit(); err != nil { + return nil, errors.Wrap(err, "committing transaction") + } + return result, nil +} + +var _ Tx = (*sqlTx)(nil) +var _ ServiceStorage = (*SQLDB)(nil) diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index b29e37375..6674a60d7 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -23,8 +23,9 @@ type Tx interface { } const ( - Bolt Type = "bolt" - Redis Type = "redis" + Bolt Type = "bolt" + DatabaseSQL Type = "database_sql" + Redis Type = "redis" // Common options