diff --git a/README.md b/README.md index b8a346c..87d5d1e 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,26 @@ See [Godoc](https://pkg.go.dev/github.com/dwango/yashiro). go get github.com/dwango/yashiro ``` +### Authorization + +AWS + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "ssm:GetParameter", + "secretsmanager:GetSecretValue" + ], + "Resource": ["*"], + }, + ] +} +``` + ## CLI Tool ### Installation diff --git a/alias.go b/alias.go index 91ce7c3..69ce56d 100644 --- a/alias.go +++ b/alias.go @@ -30,4 +30,7 @@ type Config = config.Config var ( // NewEngine returns a new Engine. NewEngine = engine.New + + // IgnoreNotFound is an option to ignore missing external store values. + IgnoreNotFound = engine.IgnoreNotFound ) diff --git a/example_test.go b/example_test.go index edc6474..dfbf5e6 100644 --- a/example_test.go +++ b/example_test.go @@ -20,6 +20,7 @@ import ( "context" "log" "os" + "time" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/dwango/yashiro" @@ -36,6 +37,13 @@ func Example() { refName := "example" cfg := &config.Config{ + Global: config.GlobalConfig{ + EnableCache: true, // enable cache + Cache: config.CacheConfig{ + Type: config.CacheTypeMemory, + ExpireDuration: config.Duration(30 * time.Minute), + }, + }, Aws: &config.AwsConfig{ ParameterStoreValues: []config.AwsParameterStoreValueConfig{ { @@ -50,7 +58,7 @@ func Example() { }, } - eng, err := yashiro.NewEngine(cfg) + eng, err := yashiro.NewEngine(cfg, yashiro.IgnoreNotFound(true)) // ignore not found value if err != nil { log.Fatalf("failed to create engine: %s", err) } diff --git a/go.mod b/go.mod index fddb09b..dc10c9c 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.27.16 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.29.1 github.com/aws/aws-sdk-go-v2/service/ssm v1.50.4 + github.com/aws/aws-sdk-go-v2/service/sts v1.28.10 github.com/spf13/cobra v1.8.0 golang.org/x/crypto v0.3.0 sigs.k8s.io/yaml v1.4.0 @@ -25,7 +26,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.20.9 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.3 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.28.10 // indirect github.com/aws/smithy-go v1.20.2 // indirect github.com/google/uuid v1.1.1 // indirect github.com/huandu/xstrings v1.3.3 // indirect diff --git a/internal/client/aws.go b/internal/client/aws.go index fed09a7..dff40d6 100644 --- a/internal/client/aws.go +++ b/internal/client/aws.go @@ -21,39 +21,53 @@ import ( "errors" "fmt" - kms "github.com/aws/aws-sdk-go-v2/service/secretsmanager" - kmsTypes "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types" + "github.com/aws/aws-sdk-go-v2/aws" + secs "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + secsTypes "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types" "github.com/aws/aws-sdk-go-v2/service/ssm" ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/dwango/yashiro/internal/client/cache" "github.com/dwango/yashiro/internal/values" "github.com/dwango/yashiro/pkg/config" ) -type ssmClient interface { - GetParameter(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) -} - -type kmsClient interface { - GetSecretValue(ctx context.Context, params *kms.GetSecretValueInput, optFns ...func(*kms.Options)) (*kms.GetSecretValueOutput, error) -} - type awsClient struct { ssmClient ssmClient - kmsClient kmsClient + secsClient secsClient parameterStoreValue []config.AwsParameterStoreValueConfig secretsManagerValue []config.ValueConfig } -func newAwsClient(cfg *config.AwsConfig) (Client, error) { - if cfg.SdkConfig == nil { +func newAwsClient(cfg *config.Config) (Client, error) { + if cfg.Aws.SdkConfig == nil { return nil, fmt.Errorf("require aws sdk config") } + var cc cache.Cache + if cfg.Global.EnableCache { + // get AWS account ID + accountID, err := getAwsAccountId(cfg.Aws.SdkConfig) + if err != nil { + return nil, err + } + cc, err = cache.New(cfg.Global.Cache, cache.WithCacheKeys("aws", cfg.Aws.SdkConfig.Region, accountID)) + if err != nil { + return nil, err + } + } + return &awsClient{ - ssmClient: ssm.NewFromConfig(*cfg.SdkConfig), - kmsClient: kms.NewFromConfig(*cfg.SdkConfig), - parameterStoreValue: cfg.ParameterStoreValues, - secretsManagerValue: cfg.SecretsManagerValues, + ssmClient: &ssmClientWithCache{ + client: ssm.NewFromConfig(*cfg.Aws.SdkConfig), + cache: cc, + }, + secsClient: &secsClientWithCache{ + client: secs.NewFromConfig(*cfg.Aws.SdkConfig), + cache: cc, + }, + parameterStoreValue: cfg.Aws.ParameterStoreValues, + secretsManagerValue: cfg.Aws.SecretsManagerValues, }, nil } @@ -80,12 +94,12 @@ func (c awsClient) GetValues(ctx context.Context, ignoreNotFound bool) (values.V } for _, v := range c.secretsManagerValue { - output, err := c.kmsClient.GetSecretValue(ctx, &kms.GetSecretValueInput{ + output, err := c.secsClient.GetSecretValue(ctx, &secs.GetSecretValueInput{ SecretId: &v.Name, }) if err != nil { - var notFoundErr *kmsTypes.ResourceNotFoundException + var notFoundErr *secsTypes.ResourceNotFoundException if ignoreNotFound && errors.As(err, ¬FoundErr) { continue } @@ -99,3 +113,109 @@ func (c awsClient) GetValues(ctx context.Context, ignoreNotFound bool) (values.V return values, nil } + +type ssmClient interface { + GetParameter(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) +} + +type ssmClientWithCache struct { + client ssmClient + cache cache.Cache +} + +func (c ssmClientWithCache) GetParameter(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + if c.cache == nil { + return c.getParameter(ctx, params, optFns...) + } + + key := *params.Name // Name is required, so do not check nil + isSensitive := params.WithDecryption != nil && *params.WithDecryption + + // Load from cache. + value, expired, err := c.cache.Load(ctx, key, isSensitive) + if err != nil { + return nil, err + } + + // If a cache value is expired or not found, get a value from the external store. + if value == nil || expired { + output, err := c.getParameter(ctx, params, optFns...) + if err != nil { + return nil, err + } + + // Create or update cache. + if err := c.cache.Save(ctx, key, output.Parameter.Value, isSensitive); err != nil { + return nil, err + } + + return output, nil + } + + return &ssm.GetParameterOutput{Parameter: &ssmTypes.Parameter{Value: value}}, nil +} + +func (c ssmClientWithCache) getParameter(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + output, err := c.client.GetParameter(ctx, params, optFns...) + if err != nil { + return nil, err + } + return output, nil +} + +type secsClient interface { + GetSecretValue(ctx context.Context, params *secs.GetSecretValueInput, optFns ...func(*secs.Options)) (*secs.GetSecretValueOutput, error) +} + +type secsClientWithCache struct { + client secsClient + cache cache.Cache +} + +func (c secsClientWithCache) GetSecretValue(ctx context.Context, params *secs.GetSecretValueInput, optFns ...func(*secs.Options)) (*secs.GetSecretValueOutput, error) { + if c.cache == nil { + return c.getSecretValue(ctx, params, optFns...) + } + + key := *params.SecretId // SecretId is required, so do not check nil + + // Load from cache. Secret is always sensitive. + value, expired, err := c.cache.Load(ctx, key, true) + if err != nil { + return nil, err + } + + // If a cache value is expired or not found, get a value from the external store. + if value == nil || expired { + output, err := c.getSecretValue(ctx, params, optFns...) + if err != nil { + return nil, err + } + + // Create or update cache. + if err := c.cache.Save(ctx, key, output.SecretString, true); err != nil { + return nil, err + } + + return output, nil + } + + return &secs.GetSecretValueOutput{SecretString: value}, nil +} + +func (c secsClientWithCache) getSecretValue(ctx context.Context, params *secs.GetSecretValueInput, optFns ...func(*secs.Options)) (*secs.GetSecretValueOutput, error) { + output, err := c.client.GetSecretValue(ctx, params, optFns...) + if err != nil { + return nil, err + } + return output, nil +} + +func getAwsAccountId(sdkConfig *aws.Config) (string, error) { + stsClient := sts.NewFromConfig(*sdkConfig) + output, err := stsClient.GetCallerIdentity(context.Background(), &sts.GetCallerIdentityInput{}) + if err != nil { + return "", err + } + return *output.Account, nil +} diff --git a/internal/client/aws_test.go b/internal/client/aws_test.go index 6a101d6..7d82bbd 100644 --- a/internal/client/aws_test.go +++ b/internal/client/aws_test.go @@ -21,17 +21,19 @@ import ( "reflect" "testing" - kms "github.com/aws/aws-sdk-go-v2/service/secretsmanager" - kmsTypes "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types" + "github.com/aws/aws-sdk-go-v2/aws" + secs "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + secsTypes "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types" "github.com/aws/aws-sdk-go-v2/service/ssm" ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "github.com/dwango/yashiro/internal/client/cache" "github.com/dwango/yashiro/internal/values" "github.com/dwango/yashiro/pkg/config" ) func Test_newAwsClient(t *testing.T) { type args struct { - cfg *config.AwsConfig + cfg *config.Config } tests := []struct { name string @@ -42,7 +44,7 @@ func Test_newAwsClient(t *testing.T) { { name: "error: aws sdk config is nil", args: args{ - cfg: &config.AwsConfig{}, + cfg: &config.Config{Aws: &config.AwsConfig{}}, }, wantErr: true, }, @@ -67,16 +69,32 @@ func (m mockSsmClient) GetParameter(ctx context.Context, params *ssm.GetParamete return m(ctx, params, optFns...) } -type mockKmsClient func(ctx context.Context, params *kms.GetSecretValueInput, optFns ...func(*kms.Options)) (*kms.GetSecretValueOutput, error) +type mockSecsClient func(ctx context.Context, params *secs.GetSecretValueInput, optFns ...func(*secs.Options)) (*secs.GetSecretValueOutput, error) -func (m mockKmsClient) GetSecretValue(ctx context.Context, params *kms.GetSecretValueInput, optFns ...func(*kms.Options)) (*kms.GetSecretValueOutput, error) { +func (m mockSecsClient) GetSecretValue(ctx context.Context, params *secs.GetSecretValueInput, optFns ...func(*secs.Options)) (*secs.GetSecretValueOutput, error) { return m(ctx, params, optFns...) } +var ( + textStrSsmClient = mockSsmClient(func(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + return &ssm.GetParameterOutput{ + Parameter: &ssmTypes.Parameter{ + Value: stringPtr("test"), + }, + }, nil + }) + + textStrSecsClient = mockSecsClient(func(ctx context.Context, params *secs.GetSecretValueInput, optFns ...func(*secs.Options)) (*secs.GetSecretValueOutput, error) { + return &secs.GetSecretValueOutput{ + SecretString: stringPtr("test"), + }, nil + }) +) + func Test_awsClient_GetValues(t *testing.T) { type fields struct { ssmClient ssmClient - kmsClient kmsClient + secsClient secsClient parameterStoreValue []config.AwsParameterStoreValueConfig secretsManagerValue []config.ValueConfig } @@ -84,26 +102,12 @@ func Test_awsClient_GetValues(t *testing.T) { ctx context.Context ignoreNotFound bool } - returnStrPtr := func(s string) *string { return &s } - - textStrSsmClient := mockSsmClient(func(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { - return &ssm.GetParameterOutput{ - Parameter: &ssmTypes.Parameter{ - Value: returnStrPtr("test"), - }, - }, nil - }) - textStrKmsClient := mockKmsClient(func(ctx context.Context, params *kms.GetSecretValueInput, optFns ...func(*kms.Options)) (*kms.GetSecretValueOutput, error) { - return &kms.GetSecretValueOutput{ - SecretString: returnStrPtr("test"), - }, nil - }) notFoundErrSsmClient := mockSsmClient(func(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { return nil, &ssmTypes.ParameterNotFound{} }) - notFoundErrKmsClient := mockKmsClient(func(ctx context.Context, params *kms.GetSecretValueInput, optFns ...func(*kms.Options)) (*kms.GetSecretValueOutput, error) { - return nil, &kmsTypes.ResourceNotFoundException{} + notFoundErrSecsClient := mockSecsClient(func(ctx context.Context, params *secs.GetSecretValueInput, optFns ...func(*secs.Options)) (*secs.GetSecretValueOutput, error) { + return nil, &secsTypes.ResourceNotFoundException{} }) tests := []struct { @@ -116,20 +120,20 @@ func Test_awsClient_GetValues(t *testing.T) { { name: "ok: text", fields: fields{ - ssmClient: textStrSsmClient, - kmsClient: textStrKmsClient, + ssmClient: textStrSsmClient, + secsClient: textStrSecsClient, parameterStoreValue: []config.AwsParameterStoreValueConfig{ {ValueConfig: config.ValueConfig{Name: "ssmKey"}, Decryption: nil}, }, secretsManagerValue: []config.ValueConfig{ - {Name: "kmsKey"}, + {Name: "secsKey"}, }, }, args: args{ ctx: context.Background(), ignoreNotFound: false, }, - want: values.Values{"ssmKey": "test", "kmsKey": "test"}, + want: values.Values{"ssmKey": "test", "secsKey": "test"}, }, { name: "ok: json", @@ -137,38 +141,38 @@ func Test_awsClient_GetValues(t *testing.T) { ssmClient: mockSsmClient(func(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { return &ssm.GetParameterOutput{ Parameter: &ssmTypes.Parameter{ - Value: returnStrPtr(`{"key":"value"}`), + Value: stringPtr(`{"key":"value"}`), }, }, nil }), - kmsClient: mockKmsClient(func(ctx context.Context, params *kms.GetSecretValueInput, optFns ...func(*kms.Options)) (*kms.GetSecretValueOutput, error) { - return &kms.GetSecretValueOutput{ - SecretString: returnStrPtr(`{"key":"value"}`), + secsClient: mockSecsClient(func(ctx context.Context, params *secs.GetSecretValueInput, optFns ...func(*secs.Options)) (*secs.GetSecretValueOutput, error) { + return &secs.GetSecretValueOutput{ + SecretString: stringPtr(`{"key":"value"}`), }, nil }), parameterStoreValue: []config.AwsParameterStoreValueConfig{ {ValueConfig: config.ValueConfig{Name: "ssmKey", IsJSON: true}}, }, secretsManagerValue: []config.ValueConfig{ - {Name: "kmsKey", IsJSON: true}, + {Name: "secsKey", IsJSON: true}, }, }, args: args{ ctx: context.Background(), ignoreNotFound: false, }, - want: values.Values{"ssmKey": map[string]any{"key": "value"}, "kmsKey": map[string]any{"key": "value"}}, + want: values.Values{"ssmKey": map[string]any{"key": "value"}, "secsKey": map[string]any{"key": "value"}}, }, { name: "ok: ignore not found error", fields: fields{ - ssmClient: notFoundErrSsmClient, - kmsClient: notFoundErrKmsClient, + ssmClient: notFoundErrSsmClient, + secsClient: notFoundErrSecsClient, parameterStoreValue: []config.AwsParameterStoreValueConfig{ {ValueConfig: config.ValueConfig{Name: "ssmKey"}}, }, secretsManagerValue: []config.ValueConfig{ - {Name: "kmsKey"}, + {Name: "secsKey"}, }, }, args: args{ @@ -180,8 +184,8 @@ func Test_awsClient_GetValues(t *testing.T) { { name: "error: return not found from ssm", fields: fields{ - ssmClient: notFoundErrSsmClient, - kmsClient: textStrKmsClient, + ssmClient: notFoundErrSsmClient, + secsClient: textStrSecsClient, parameterStoreValue: []config.AwsParameterStoreValueConfig{ {ValueConfig: config.ValueConfig{Name: "ssmKey"}}, }, @@ -194,13 +198,13 @@ func Test_awsClient_GetValues(t *testing.T) { wantErr: true, }, { - name: "error: return not found from kms", + name: "error: return not found from secs", fields: fields{ ssmClient: textStrSsmClient, - kmsClient: notFoundErrKmsClient, + secsClient: notFoundErrSecsClient, parameterStoreValue: []config.AwsParameterStoreValueConfig{}, secretsManagerValue: []config.ValueConfig{ - {Name: "kmsKey"}, + {Name: "secsKey"}, }, }, args: args{ @@ -215,7 +219,7 @@ func Test_awsClient_GetValues(t *testing.T) { ssmClient: mockSsmClient(func(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { return nil, &ssmTypes.InternalServerError{} }), - kmsClient: textStrKmsClient, + secsClient: textStrSecsClient, parameterStoreValue: []config.AwsParameterStoreValueConfig{ {ValueConfig: config.ValueConfig{Name: "ssmKey"}}, }, @@ -228,15 +232,15 @@ func Test_awsClient_GetValues(t *testing.T) { wantErr: true, }, { - name: "error: return another error from kms", + name: "error: return another error from secs", fields: fields{ ssmClient: textStrSsmClient, - kmsClient: mockKmsClient(func(ctx context.Context, params *kms.GetSecretValueInput, optFns ...func(*kms.Options)) (*kms.GetSecretValueOutput, error) { - return nil, &kmsTypes.InternalServiceError{} + secsClient: mockSecsClient(func(ctx context.Context, params *secs.GetSecretValueInput, optFns ...func(*secs.Options)) (*secs.GetSecretValueOutput, error) { + return nil, &secsTypes.InternalServiceError{} }), parameterStoreValue: []config.AwsParameterStoreValueConfig{}, secretsManagerValue: []config.ValueConfig{ - {Name: "kmsKey"}, + {Name: "secsKey"}, }, }, args: args{ @@ -250,7 +254,7 @@ func Test_awsClient_GetValues(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := awsClient{ ssmClient: tt.fields.ssmClient, - kmsClient: tt.fields.kmsClient, + secsClient: tt.fields.secsClient, parameterStoreValue: tt.fields.parameterStoreValue, secretsManagerValue: tt.fields.secretsManagerValue, } @@ -265,3 +269,203 @@ func Test_awsClient_GetValues(t *testing.T) { }) } } + +func Test_ssmClientWithCache_GetParameter(t *testing.T) { + var params = &ssm.GetParameterInput{Name: stringPtr("any")} + var textStrSsmClientWant = &ssm.GetParameterOutput{Parameter: &ssmTypes.Parameter{Value: stringPtr("test")}} + + type fields struct { + client ssmClient + cache cache.Cache + } + type args struct { + ctx context.Context + params *ssm.GetParameterInput + optFns []func(*ssm.Options) + } + tests := []struct { + name string + fields fields + args args + want *ssm.GetParameterOutput + wantErr bool + }{ + { + name: "ok: get from cache", + fields: fields{ + client: nil, + cache: mockCache{load: mockLoadFunc, save: mockSaveFunc}, + }, + args: args{ + ctx: context.Background(), + params: params, + }, + want: &ssm.GetParameterOutput{Parameter: &ssmTypes.Parameter{Value: stringPtr("value")}}, + }, + { + name: "ok: get from cache(cache disabled)", + fields: fields{ + client: textStrSsmClient, + cache: nil, + }, + args: args{ + ctx: context.Background(), + params: &ssm.GetParameterInput{Name: stringPtr("key")}, + }, + want: textStrSsmClientWant, + }, + { + name: "ok: get from client(no cache)", + fields: fields{ + client: textStrSsmClient, + cache: mockCache{load: mockLoadFuncNotFound, save: mockSaveFunc}, + }, + args: args{ + ctx: context.Background(), + params: params, + }, + want: textStrSsmClientWant, + }, + { + name: "ok: get from client(cache expired)", + fields: fields{ + client: textStrSsmClient, + cache: mockCache{load: mockLoadFuncExpired, save: mockSaveFunc}, + }, + args: args{ + ctx: context.Background(), + params: params, + }, + want: textStrSsmClientWant, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := ssmClientWithCache{ + client: tt.fields.client, + cache: tt.fields.cache, + } + got, err := c.GetParameter(tt.args.ctx, tt.args.params, tt.args.optFns...) + if (err != nil) != tt.wantErr { + t.Errorf("ssmClientWithCache.GetParameter() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ssmClientWithCache.GetParameter() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_secsClientWithCache_GetSecretValue(t *testing.T) { + var params = &secs.GetSecretValueInput{SecretId: stringPtr("any")} + var textStrSecsClientWant = &secs.GetSecretValueOutput{SecretString: stringPtr("test")} + + type fields struct { + client secsClient + cache cache.Cache + } + type args struct { + ctx context.Context + params *secs.GetSecretValueInput + optFns []func(*secs.Options) + } + tests := []struct { + name string + fields fields + args args + want *secs.GetSecretValueOutput + wantErr bool + }{ + { + name: "ok: get from cache", + fields: fields{ + client: nil, + cache: mockCache{load: mockLoadFunc, save: mockSaveFunc}, + }, + args: args{ + ctx: context.Background(), + params: params, + }, + want: &secs.GetSecretValueOutput{SecretString: stringPtr("value")}, + }, + { + name: "ok: get from cache(cache disabled)", + fields: fields{ + client: textStrSecsClient, + cache: nil, + }, + args: args{ + ctx: context.Background(), + params: params, + }, + want: textStrSecsClientWant, + }, + { + name: "ok: get from client(no cache)", + fields: fields{ + client: textStrSecsClient, + cache: mockCache{load: mockLoadFuncNotFound, save: mockSaveFunc}, + }, + args: args{ + ctx: context.Background(), + params: params, + }, + want: textStrSecsClientWant, + }, + { + name: "ok: get from client(cache expired)", + fields: fields{ + client: textStrSecsClient, + cache: mockCache{load: mockLoadFuncExpired, save: mockSaveFunc}, + }, + args: args{ + ctx: context.Background(), + params: params, + }, + want: textStrSecsClientWant, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := secsClientWithCache{ + client: tt.fields.client, + cache: tt.fields.cache, + } + got, err := c.GetSecretValue(tt.args.ctx, tt.args.params, tt.args.optFns...) + if (err != nil) != tt.wantErr { + t.Errorf("secsClientWithCache.GetSecretValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("secsClientWithCache.GetSecretValue() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getAwsAccountId(t *testing.T) { + type args struct { + sdkConfig *aws.Config + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getAwsAccountId(tt.args.sdkConfig) + if (err != nil) != tt.wantErr { + t.Errorf("getAwsAccountId() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("getAwsAccountId() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/client/cache.go b/internal/client/cache.go deleted file mode 100644 index b15cc07..0000000 --- a/internal/client/cache.go +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2024 DWANGO Co., Ltd. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package client - -import ( - "context" - - "github.com/dwango/yashiro/internal/client/cache" - "github.com/dwango/yashiro/internal/values" -) - -type clientWithCache struct { - client Client - cache cache.Cache -} - -func newClientWithCache(client Client, cache cache.Cache) Client { - return &clientWithCache{ - client: client, - cache: cache, - } -} - -// GetValues implements Client. -func (c *clientWithCache) GetValues(ctx context.Context, ignoreNotFound bool) (values.Values, error) { - val, expired, err := c.cache.Load(ctx) - if err != nil { - return nil, err - } - - // if cache is empty, get values from external store. - if len(val) == 0 || expired { - val, err = c.client.GetValues(ctx, ignoreNotFound) - if err != nil { - return nil, err - } - } - - // save values to cache - if expired { - if err := c.cache.Save(ctx, val); err != nil { - return nil, err - } - } - - return val, nil -} diff --git a/internal/client/cache/cache.go b/internal/client/cache/cache.go index 3dfc2b3..a337911 100644 --- a/internal/client/cache/cache.go +++ b/internal/client/cache/cache.go @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package cache import ( "context" + "encoding/hex" "errors" "fmt" + "time" - "github.com/dwango/yashiro/internal/values" "github.com/dwango/yashiro/pkg/config" ) @@ -29,21 +31,30 @@ var ( ) type Cache interface { - // Load returns values from cache and whether or not cache is expired. If cache is empty, - // returned values is empty and expired=true. - Load(ctx context.Context) (values.Values, bool, error) + // Load returns cached string by using the key, and whether or not cache is expired. If a cache is empty, + // returned a string is nil and expired is true. + Load(ctx context.Context, key string, decrypt bool) (*string, bool, error) - // Save saves values to cache. - Save(ctx context.Context, val values.Values) error + // Save saves value to cache. If encrypt is true, value is encrypted before saving. + Save(ctx context.Context, key string, value *string, encrypt bool) error } -func New(cfg config.CacheConfig) (Cache, error) { +func New(cfg config.CacheConfig, options ...Option) (Cache, error) { + expireDuration := config.DefaultExpireDuration + if cfg.ExpireDuration != 0 { + expireDuration = time.Duration(cfg.ExpireDuration) + } + switch cfg.Type { case config.CacheTypeUnspecified, config.CacheTypeMemory: - return newMemoryCache() + return newMemoryCache(expireDuration) case config.CacheTypeFile: - return newFileCache(cfg.File) + return newFileCache(cfg.File, expireDuration, options...) default: return nil, fmt.Errorf("%w: %s", ErrInvalidCacheType, cfg.Type) } } + +func keyToHex(key string) string { + return hex.EncodeToString([]byte(key)) +} diff --git a/internal/client/cache/cache_test.go b/internal/client/cache/cache_test.go new file mode 100644 index 0000000..4793da7 --- /dev/null +++ b/internal/client/cache/cache_test.go @@ -0,0 +1,66 @@ +/** + * Copyright 2024 DWANGO Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cache + +import ( + "reflect" + "testing" + "time" + + "github.com/dwango/yashiro/pkg/config" +) + +func TestNew(t *testing.T) { + type args struct { + cfg config.CacheConfig + options []Option + } + tests := []struct { + name string + args args + want Cache + wantErr bool + }{ + { + name: "error: unknown cache type", + args: args{ + cfg: config.CacheConfig{ + Type: "unknown", + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := New(tt.args.cfg, tt.args.options...) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("New() = %v, want %v", got, tt.want) + } + }) + } +} + +const notExpireDuration = 100 * 365 * 24 * time.Hour // 100 years + +func stringPtr(s string) *string { + return &s +} diff --git a/internal/client/cache/file.go b/internal/client/cache/file.go index 61cb315..a50c559 100644 --- a/internal/client/cache/file.go +++ b/internal/client/cache/file.go @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package cache import ( @@ -20,41 +21,51 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" - "encoding/json" "errors" "io" "os" + "path/filepath" + "strings" "time" - "github.com/dwango/yashiro/internal/values" "github.com/dwango/yashiro/pkg/config" "golang.org/x/crypto/bcrypt" ) const ( - cacheFileName = "values" keyFileName = "key" - keyHashFileName = ".key_hash" + keyHashFileName = "keyHash" ) var defaultCacheBasePath string type fileCache struct { - cacheBasePath string - cipherBlock cipher.Block - expired bool + cachePath string + cipherBlock cipher.Block + expireDuration time.Duration + filenamePrefix string } -func newFileCache(cfg config.FileCacheConfig) (Cache, error) { - fc := &fileCache{ - cacheBasePath: defaultCacheBasePath, +func newFileCache(cfg config.FileCacheConfig, expireDuration time.Duration, options ...Option) (Cache, error) { + opts := defaultOpts + for _, o := range options { + o(opts) } + cachePath := defaultCacheBasePath if len(cfg.CachePath) != 0 { - fc.cacheBasePath = cfg.CachePath + cachePath = cfg.CachePath + } + filenamePrefix := keyToHex(strings.Join(opts.CacheKeys, "_")) + "_" + + fc := &fileCache{ + cachePath: cachePath, + expireDuration: expireDuration, + filenamePrefix: filenamePrefix, } + // create cache directory - if err := os.MkdirAll(fc.cacheBasePath, 0777); err != nil { + if err := os.MkdirAll(fc.cachePath, 0777); err != nil { return nil, err } @@ -72,46 +83,54 @@ func newFileCache(cfg config.FileCacheConfig) (Cache, error) { return fc, nil } -const ( - // 30 days - expiredDuration time.Duration = 30 * 24 * time.Hour -) - // Load implements Cache. -func (f *fileCache) Load(_ context.Context) (values.Values, bool, error) { - fInfo, err := f.getFileStat(cacheFileName) +func (f *fileCache) Load(_ context.Context, key string, decrypt bool) (*string, bool, error) { + filename := keyToHex(key) + + fInfo, err := f.getFileInfo(filename, false) if err != nil { - f.expired = true - return nil, f.expired, nil + // cache file not found + return nil, false, nil } // check if cache is expired - f.expired = time.Since(fInfo.ModTime().Local()) > expiredDuration + expired := time.Since(fInfo.ModTime().Local()) > f.expireDuration - cacheCipherText, err := f.readFile(cacheFileName) + valueByte, err := f.readFile(filename, false) if err != nil { return nil, false, err } + if !decrypt { + value := string(valueByte) + return &value, expired, nil + } - val, err := f.decryptCache(cacheCipherText) + valueByte, err = f.decryptCache(valueByte) if err != nil { return nil, false, err } + value := string(valueByte) - return val, f.expired, nil + return &value, expired, nil } // Save implements Cache. -func (f *fileCache) Save(_ context.Context, val values.Values) error { - if !f.expired { +func (f *fileCache) Save(_ context.Context, key string, value *string, encrypt bool) error { + if value == nil { return nil } - encryptedCache, err := f.encryptCache(val) - if err != nil { - return err + filename := keyToHex(key) + valueByte := []byte(*value) + + if encrypt { + var err error + valueByte, err = f.encryptCache(valueByte) + if err != nil { + return err + } } - if err := f.writeToFile(cacheFileName, encryptedCache); err != nil { + if err := f.writeToFile(filename, valueByte, false); err != nil { return err } @@ -121,14 +140,14 @@ func (f *fileCache) Save(_ context.Context, val values.Values) error { func (f *fileCache) readOrCreateKey() ([]byte, error) { var key []byte // check key file exists - if _, err := f.getFileStat(keyFileName); err != nil { + if _, err := f.getFileInfo(keyFileName, false); err != nil { key = make([]byte, 32) // create key file if _, err := rand.Read(key); err != nil { return nil, err } - if err := f.writeToFile(keyFileName, key); err != nil { + if err := f.writeToFile(keyFileName, key, false); err != nil { return nil, err } @@ -137,7 +156,7 @@ func (f *fileCache) readOrCreateKey() ([]byte, error) { if err != nil { return nil, err } - if err := f.writeToFile(keyHashFileName, keyHash); err != nil { + if err := f.writeToFile(keyHashFileName, keyHash, true); err != nil { return nil, err } @@ -146,13 +165,13 @@ func (f *fileCache) readOrCreateKey() ([]byte, error) { var err error // read key file - key, err = f.readFile(keyFileName) + key, err = f.readFile(keyFileName, false) if err != nil { return nil, err } // read key hash file - keyHash, err := f.readFile(keyHashFileName) + keyHash, err := f.readFile(keyHashFileName, true) if err != nil { return nil, err } @@ -165,59 +184,70 @@ func (f *fileCache) readOrCreateKey() ([]byte, error) { return key, nil } -func (f *fileCache) decryptCache(cacheCipherText []byte) (values.Values, error) { - if len(cacheCipherText) < aes.BlockSize { +func (f *fileCache) decryptCache(cipherText []byte) ([]byte, error) { + if len(cipherText) < aes.BlockSize { return nil, errors.New("ciphertext too short") } - iv := cacheCipherText[:aes.BlockSize] - cacheCipherText = cacheCipherText[aes.BlockSize:] - - cachePlainText := make([]byte, len(cacheCipherText)) - stream := cipher.NewOFB(f.cipherBlock, iv) - stream.XORKeyStream(cachePlainText, cacheCipherText) + iv := cipherText[:aes.BlockSize] + cipherText = cipherText[aes.BlockSize:] - values := make(values.Values) - if err := json.Unmarshal(cachePlainText, &values); err != nil { - return nil, err - } + stream := cipher.NewCFBDecrypter(f.cipherBlock, iv) + stream.XORKeyStream(cipherText, cipherText) - return values, nil + return cipherText, nil } -func (f *fileCache) encryptCache(values values.Values) ([]byte, error) { - cacheJSON, err := json.Marshal(values) - if err != nil { - return nil, err - } - - cacheCipherText := make([]byte, aes.BlockSize+len(cacheJSON)) - iv := cacheCipherText[:aes.BlockSize] +func (f *fileCache) encryptCache(plainText []byte) ([]byte, error) { + cipherText := make([]byte, aes.BlockSize+len(plainText)) + iv := cipherText[:aes.BlockSize] if _, err := io.ReadFull(rand.Reader, iv); err != nil { return nil, err } - stream := cipher.NewOFB(f.cipherBlock, iv) - stream.XORKeyStream(cacheCipherText[aes.BlockSize:], cacheJSON) + stream := cipher.NewCFBEncrypter(f.cipherBlock, iv) + stream.XORKeyStream(cipherText[aes.BlockSize:], plainText) - return cacheCipherText, nil + return cipherText, nil } -func (f fileCache) getFileStat(filename string) (os.FileInfo, error) { - return os.Stat(f.cacheBasePath + "/" + filename) +func (f fileCache) getFileInfo(filename string, hidden bool) (os.FileInfo, error) { + filename = f.filenamePrefix + filename + if hidden { + filename = "." + filename + } + + return os.Stat((filepath.Join(f.cachePath, filename))) } -func (f fileCache) readFile(filename string) ([]byte, error) { - return os.ReadFile(f.cacheBasePath + "/" + filename) +func (f fileCache) readFile(filename string, hidden bool) ([]byte, error) { + filename = f.filenamePrefix + filename + if hidden { + filename = "." + filename + } + + data, err := os.ReadFile(filepath.Join(f.cachePath, filename)) + if err != nil { + return nil, err + } + data = data[:len(data)-1] + + return data, nil } -func (f fileCache) writeToFile(filename string, data []byte) error { - file, err := os.Create(f.cacheBasePath + "/" + filename) +func (f fileCache) writeToFile(filename string, data []byte, hidden bool) error { + filename = f.filenamePrefix + filename + if hidden { + filename = "." + filename + } + + file, err := os.Create(filepath.Join(f.cachePath, filename)) if err != nil { return err } defer file.Close() + data = append(data, '\n') if _, err := file.Write(data); err != nil { return err } @@ -226,12 +256,13 @@ func (f fileCache) writeToFile(filename string, data []byte) error { } func init() { - const cachePath = "/yashiro" + const cachePath = "yashiro" cacheDir, err := os.UserCacheDir() if err != nil { - defaultCacheBasePath = "/tmp" + cachePath + "/cache" + defaultCacheBasePath = filepath.Join(os.TempDir(), cachePath, "cache") return } - defaultCacheBasePath = cacheDir + cachePath + + defaultCacheBasePath = filepath.Join(cacheDir, cachePath) } diff --git a/internal/client/cache/file_test.go b/internal/client/cache/file_test.go index 1e6e745..330ffc3 100644 --- a/internal/client/cache/file_test.go +++ b/internal/client/cache/file_test.go @@ -13,196 +13,168 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package cache import ( "context" "crypto/aes" "crypto/cipher" + "os" "reflect" "testing" + "time" - "github.com/dwango/yashiro/internal/values" "github.com/dwango/yashiro/pkg/config" ) func Test_newFileCache(t *testing.T) { type args struct { - cfg config.FileCacheConfig + cfg config.FileCacheConfig + expireDuration time.Duration + options []Option } tests := []struct { - name string - args args - want Cache - wantErr bool + name string + args args + wantFiles []string + wantErr bool }{ - // TODO: Add test cases. + { + name: "ok", + args: args{ + cfg: config.FileCacheConfig{ + CachePath: "testdata/constructor", + }, + }, + wantFiles: []string{"testdata/constructor/_key", "testdata/constructor/._keyHash"}, + }, + { + name: "ok with cache keys option", + args: args{ + cfg: config.FileCacheConfig{ + CachePath: "testdata/constructor", + }, + options: []Option{WithCacheKeys("key1")}, + }, + wantFiles: []string{"testdata/constructor/6b657931_key", "testdata/constructor/.6b657931_keyHash"}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := newFileCache(tt.args.cfg) + _, err := newFileCache(tt.args.cfg, tt.args.expireDuration, tt.args.options...) if (err != nil) != tt.wantErr { t.Errorf("newFileCache() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("newFileCache() = %v, want %v", got, tt.want) + for _, file := range tt.wantFiles { + if _, err := os.Stat(file); err != nil { + t.Errorf("file is not found = %v", file) + } } }) } } func Test_fileCache_SaveAndLoad(t *testing.T) { + const cachePath = "testdata/save-and-load" block, _ := aes.NewCipher([]byte("0123456789abcdef0123456789abcdef")) type fields struct { - cacheBasePath string - cipherBlock cipher.Block - expired bool + cachePath string + cipherBlock cipher.Block + expireDuration time.Duration + filenamePrefix string } type args struct { - in0 context.Context - val values.Values + in0 context.Context + key string + value *string + encrypt bool } tests := []struct { - name string - fields fields - args args - want values.Values - wantErr bool + name string + fields fields + args args + wantExpired bool + wantErr bool }{ { - name: "ok: save and load", + name: "ok: save and load plain value", fields: fields{ - cacheBasePath: "testdata/save-and-load", - cipherBlock: block, - expired: true, + cachePath: cachePath, + cipherBlock: block, + expireDuration: notExpireDuration, }, args: args{ - in0: context.Background(), - val: values.Values{ - "key": "value", - }, + key: "plain-key", + value: stringPtr("plain-value"), }, - want: values.Values{ - "key": "value", - }, - wantErr: false, }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - f := &fileCache{ - cacheBasePath: tt.fields.cacheBasePath, - cipherBlock: tt.fields.cipherBlock, - expired: tt.fields.expired, - } - if err := f.Save(tt.args.in0, tt.args.val); (err != nil) != tt.wantErr { - t.Errorf("fileCache.Save() error = %v, wantErr %v", err, tt.wantErr) - } - got, _, err := f.Load(tt.args.in0) - if (err != nil) != tt.wantErr { - t.Errorf("fileCache.Load() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("fileCache.Load() got = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_fileCache_readOrCreateKey(t *testing.T) { - type fields struct { - cacheBasePath string - cipherBlock cipher.Block - expired bool - } - tests := []struct { - name string - fields fields - wantErr bool - }{ { - // This test case is executed only once. Therefore, if you want to retest, delete the file - // before executing it again. - name: "ok: create key", + name: "ok: save and load encrypted value", fields: fields{ - cacheBasePath: "testdata/read-or-create-key", + cachePath: cachePath, + cipherBlock: block, + expireDuration: notExpireDuration, + }, + args: args{ + key: "encrypted-key", + value: stringPtr("encrypted-value"), + encrypt: true, }, - wantErr: false, }, { - name: "ok: read key", + name: "ok: load expired value", fields: fields{ - cacheBasePath: "testdata/read-or-create-key", + cachePath: cachePath, + cipherBlock: block, + expireDuration: 0, + }, + args: args{ + key: "expired-key", + value: stringPtr("expired-value"), }, - wantErr: false, + wantExpired: true, }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - f := &fileCache{ - cacheBasePath: tt.fields.cacheBasePath, - cipherBlock: tt.fields.cipherBlock, - expired: tt.fields.expired, - } - if _, err := f.readOrCreateKey(); (err != nil) != tt.wantErr { - t.Errorf("fileCache.readOrCreateKey() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} - -func Test_fileCache_encryptAndDecryptCache(t *testing.T) { - block, _ := aes.NewCipher([]byte("0123456789abcdef0123456789abcdef")) - - type fields struct { - cacheBasePath string - cipherBlock cipher.Block - expired bool - } - type args struct { - values values.Values - } - tests := []struct { - name string - fields fields - args args - wantErr bool - }{ { - name: "ok: encrypt values", + name: "ok: with prefix", fields: fields{ - cipherBlock: block, + cachePath: cachePath, + cipherBlock: block, + expireDuration: notExpireDuration, + filenamePrefix: "test_", }, args: args{ - values: values.Values{ - "key": "value", - }, + key: "prefix-key", + value: stringPtr("prefix-value"), }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { f := &fileCache{ - cacheBasePath: tt.fields.cacheBasePath, - cipherBlock: tt.fields.cipherBlock, - expired: tt.fields.expired, + cachePath: tt.fields.cachePath, + cipherBlock: tt.fields.cipherBlock, + expireDuration: tt.fields.expireDuration, + filenamePrefix: tt.fields.filenamePrefix, } - gotEncrypt, err := f.encryptCache(tt.args.values) - if (err != nil) != tt.wantErr { - t.Errorf("fileCache.encryptCache() error = %v, wantErr %v", err, tt.wantErr) - return + // Save + if err := f.Save(tt.args.in0, tt.args.key, tt.args.value, tt.args.encrypt); (err != nil) != tt.wantErr { + t.Errorf("fileCache.Save() error = %v, wantErr %v", err, tt.wantErr) } - gotDecrypt, err := f.decryptCache(gotEncrypt) + + // Load + gotValue, gotExpired, err := f.Load(tt.args.in0, tt.args.key, tt.args.encrypt) if (err != nil) != tt.wantErr { - t.Errorf("fileCache.decryptCache() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("fileCache.Load() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(gotDecrypt, tt.args.values) { - t.Errorf("fileCache.decryptCache() = %v, want %v", gotDecrypt, tt.args.values) + if gotExpired != tt.wantExpired { + t.Errorf("fileCache.Load() expired = %v, want %v", gotExpired, tt.wantExpired) + } + if !reflect.DeepEqual(gotValue, tt.args.value) { + t.Errorf("fileCache.Load() got = %v, want %v", *gotValue, *tt.args.value) } }) } diff --git a/internal/client/cache/memory.go b/internal/client/cache/memory.go index ff6b3a9..c33200c 100644 --- a/internal/client/cache/memory.go +++ b/internal/client/cache/memory.go @@ -13,38 +13,66 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package cache import ( "context" - "maps" - - "github.com/dwango/yashiro/internal/values" + "strings" + "time" ) type memoryCache struct { - values values.Values + caches map[string]*cacheData + expireDuration time.Duration + keyPrefix string +} + +func newMemoryCache(expireDuration time.Duration, options ...Option) (Cache, error) { + opts := defaultOpts + for _, o := range options { + o(opts) + } + + keyPrefix := keyToHex(strings.Join(opts.CacheKeys, "_")) + "_" + + return &memoryCache{ + caches: make(map[string]*cacheData), + expireDuration: expireDuration, + keyPrefix: keyPrefix, + }, nil } -func newMemoryCache() (Cache, error) { - return &memoryCache{}, nil +type cacheData struct { + value string + saveTime time.Time } // Load implements Cache. -func (m memoryCache) Load(_ context.Context) (values.Values, bool, error) { - expired := false - if len(m.values) == 0 { - expired = true +func (m memoryCache) Load(_ context.Context, key string, _ bool) (*string, bool, error) { + data, ok := m.caches[m.keyPrefix+key] + if !ok { + return nil, false, nil + } + + if time.Since(data.saveTime) > m.expireDuration { + return &data.value, true, nil } - return m.values, expired, nil + return &data.value, false, nil } // Save implements Cache. -func (m *memoryCache) Save(_ context.Context, val values.Values) error { - newVal := make(values.Values, len(val)) - maps.Copy(newVal, val) - m.values = newVal +func (m *memoryCache) Save(_ context.Context, key string, value *string, _ bool) error { + if value == nil { + return nil + } + + data := &cacheData{ + value: *value, + saveTime: time.Now(), + } + m.caches[m.keyPrefix+key] = data return nil } diff --git a/internal/client/cache/memory_test.go b/internal/client/cache/memory_test.go index 95fdd10..3360d7d 100644 --- a/internal/client/cache/memory_test.go +++ b/internal/client/cache/memory_test.go @@ -13,19 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package cache import ( "context" "reflect" "testing" - - "github.com/dwango/yashiro/internal/values" + "time" ) func Test_newMemoryCache(t *testing.T) { + type args struct { + expireDuration time.Duration + options []Option + } tests := []struct { name string + args args want Cache wantErr bool }{ @@ -33,7 +38,7 @@ func Test_newMemoryCache(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := newMemoryCache() + got, err := newMemoryCache(tt.args.expireDuration, tt.args.options...) if (err != nil) != tt.wantErr { t.Errorf("newMemoryCache() error = %v, wantErr %v", err, tt.wantErr) return @@ -45,104 +50,84 @@ func Test_newMemoryCache(t *testing.T) { } } -func Test_memoryCache_Load(t *testing.T) { +func Test_memoryCache_SaveAndLoad(t *testing.T) { type fields struct { - values values.Values + caches map[string]*cacheData + expireDuration time.Duration + keyPrefix string } type args struct { - in0 context.Context + in0 context.Context + key string + value *string + in3 bool } tests := []struct { - name string - fields fields - args args - want values.Values - want1 bool - wantErr bool + name string + fields fields + args args + wantExpired bool + wantErr bool }{ { - name: "ok: get values", + name: "ok: save and load", fields: fields{ - values: values.Values{ - "key": "value", - }, + caches: make(map[string]*cacheData), + expireDuration: notExpireDuration, }, args: args{ - in0: context.Background(), - }, - want: values.Values{ - "key": "value", + key: "key", + value: stringPtr("value"), }, }, { - name: "ok: no values(return expired=true)", + name: "ok: load expired value", fields: fields{ - values: nil, + caches: make(map[string]*cacheData), + expireDuration: 0, }, args: args{ - in0: context.Background(), + key: "expired-key", + value: stringPtr("expired-value"), }, - want: nil, - want1: true, + wantExpired: true, }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := memoryCache{ - values: tt.fields.values, - } - got, got1, err := m.Load(tt.args.in0) - if (err != nil) != tt.wantErr { - t.Errorf("memoryCache.Load() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("memoryCache.Load() got = %v, want %v", got, tt.want) - } - if got1 != tt.want1 { - t.Errorf("memoryCache.Load() got1 = %v, want %v", got1, tt.want1) - } - }) - } -} - -func Test_memoryCache_Save(t *testing.T) { - type fields struct { - values values.Values - } - type args struct { - in0 context.Context - val values.Values - } - tests := []struct { - name string - fields fields - args args - wantErr bool - }{ { - name: "ok: save values", + name: "ok: with key prefix", fields: fields{ - values: values.Values{}, + caches: make(map[string]*cacheData), + expireDuration: notExpireDuration, + keyPrefix: "prefix_", }, args: args{ - in0: context.Background(), - val: values.Values{ - "key": "value", - }, + key: "prefix-key", + value: stringPtr("prefix-value"), }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m := &memoryCache{ - values: tt.fields.values, + caches: tt.fields.caches, + expireDuration: tt.fields.expireDuration, + keyPrefix: tt.fields.keyPrefix, } - if err := m.Save(tt.args.in0, tt.args.val); (err != nil) != tt.wantErr { + // Save + if err := m.Save(tt.args.in0, tt.args.key, tt.args.value, tt.args.in3); (err != nil) != tt.wantErr { t.Errorf("memoryCache.Save() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(m.values, tt.args.val) { - t.Errorf("memoryCache.Save() got = %v, want %v", m.values, tt.args.val) + + // Load + gotValue, gotExpired, err := m.Load(tt.args.in0, tt.args.key, tt.args.in3) + if (err != nil) != tt.wantErr { + t.Errorf("memoryCache.Load() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotExpired != tt.wantExpired { + t.Errorf("memoryCache.Load() expired = %v, want %v", gotExpired, tt.wantExpired) + } + if !reflect.DeepEqual(gotValue, tt.args.value) { + t.Errorf("memoryCache.Load() got = %v, want %v", gotValue, tt.args.value) } }) } diff --git a/internal/client/cache/options.go b/internal/client/cache/options.go new file mode 100644 index 0000000..24bed07 --- /dev/null +++ b/internal/client/cache/options.go @@ -0,0 +1,35 @@ +/** + * Copyright 2024 DWANGO Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cache + +// Option is configurable Cache behavior. +type Option func(*opts) + +// WithCacheKeys sets the keys to be used in the cache. +func WithCacheKeys(keys ...string) Option { + return func(o *opts) { + o.CacheKeys = keys + } +} + +type opts struct { + CacheKeys []string +} + +var defaultOpts = &opts{ + CacheKeys: nil, +} diff --git a/internal/client/cache/testdata/constructor/.6b657931_keyHash b/internal/client/cache/testdata/constructor/.6b657931_keyHash new file mode 100644 index 0000000..fc6cf43 --- /dev/null +++ b/internal/client/cache/testdata/constructor/.6b657931_keyHash @@ -0,0 +1 @@ +$2a$05$qARd8goKfJMYeY0BEjOgLOGBOH5hpas01ypsJQcTsifPBLOzdKkd. diff --git a/internal/client/cache/testdata/constructor/._keyHash b/internal/client/cache/testdata/constructor/._keyHash new file mode 100644 index 0000000..ab0bc8e --- /dev/null +++ b/internal/client/cache/testdata/constructor/._keyHash @@ -0,0 +1 @@ +$2a$05$YcS1L00dfYo1dfJOMRLIY.go1GB81hWS1ZyBcm6uHgBMYiWU1GEPK diff --git a/internal/client/cache/testdata/constructor/6b657931_key b/internal/client/cache/testdata/constructor/6b657931_key new file mode 100644 index 0000000..a0038d7 Binary files /dev/null and b/internal/client/cache/testdata/constructor/6b657931_key differ diff --git a/internal/client/cache/testdata/constructor/_key b/internal/client/cache/testdata/constructor/_key new file mode 100644 index 0000000..67d60df --- /dev/null +++ b/internal/client/cache/testdata/constructor/_key @@ -0,0 +1 @@ +jù; <뇘ì5(£*8ÏÝ.2/ Fp”á‰suÝí diff --git a/internal/client/cache/testdata/read-or-create-key/.key_hash b/internal/client/cache/testdata/read-or-create-key/.key_hash deleted file mode 100644 index a23fbc4..0000000 --- a/internal/client/cache/testdata/read-or-create-key/.key_hash +++ /dev/null @@ -1 +0,0 @@ -$2a$05$mddC5bB9QF83.WYeYbbMLu2mkywrUIBs9Bi.FXkfCAcNLH8YN8UTW \ No newline at end of file diff --git a/internal/client/cache/testdata/read-or-create-key/key b/internal/client/cache/testdata/read-or-create-key/key deleted file mode 100644 index 59941e6..0000000 --- a/internal/client/cache/testdata/read-or-create-key/key +++ /dev/null @@ -1 +0,0 @@ -a-Céæ§ym¡Ëøgyl«ÇBÚ¸|I{—·è)žÇ \ No newline at end of file diff --git a/internal/client/cache/testdata/save-and-load/656e637279707465642d6b6579 b/internal/client/cache/testdata/save-and-load/656e637279707465642d6b6579 new file mode 100644 index 0000000..3d2b52e --- /dev/null +++ b/internal/client/cache/testdata/save-and-load/656e637279707465642d6b6579 @@ -0,0 +1 @@ +4ˆ4bQÌPÒWíJ®¯ØSÛÿ°©cëEö~|ÉŠ diff --git a/internal/client/cache/testdata/save-and-load/657870697265642d6b6579 b/internal/client/cache/testdata/save-and-load/657870697265642d6b6579 new file mode 100644 index 0000000..7268620 --- /dev/null +++ b/internal/client/cache/testdata/save-and-load/657870697265642d6b6579 @@ -0,0 +1 @@ +expired-value diff --git a/internal/client/cache/testdata/save-and-load/706c61696e2d6b6579 b/internal/client/cache/testdata/save-and-load/706c61696e2d6b6579 new file mode 100644 index 0000000..efa1f0e --- /dev/null +++ b/internal/client/cache/testdata/save-and-load/706c61696e2d6b6579 @@ -0,0 +1 @@ +plain-value diff --git a/internal/client/cache/testdata/save-and-load/test_7072656669782d6b6579 b/internal/client/cache/testdata/save-and-load/test_7072656669782d6b6579 new file mode 100644 index 0000000..40d9fe5 --- /dev/null +++ b/internal/client/cache/testdata/save-and-load/test_7072656669782d6b6579 @@ -0,0 +1 @@ +prefix-value diff --git a/internal/client/cache/testdata/save-and-load/values b/internal/client/cache/testdata/save-and-load/values deleted file mode 100644 index dfc9155..0000000 --- a/internal/client/cache/testdata/save-and-load/values +++ /dev/null @@ -1 +0,0 @@ -S!ª´? ´(‰1z¢Á IôDâÞ]Q*Z±y·HÏ \ No newline at end of file diff --git a/internal/client/cache_test.go b/internal/client/cache_test.go deleted file mode 100644 index d30fc45..0000000 --- a/internal/client/cache_test.go +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2024 DWANGO Co., Ltd. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package client - -import ( - "context" - "reflect" - "testing" - - "github.com/dwango/yashiro/internal/client/cache" - "github.com/dwango/yashiro/internal/values" -) - -func Test_newClientWithCache(t *testing.T) { - type args struct { - client Client - cache cache.Cache - } - tests := []struct { - name string - args args - want Client - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := newClientWithCache(tt.args.client, tt.args.cache); !reflect.DeepEqual(got, tt.want) { - t.Errorf("newClientWithCache() = %v, want %v", got, tt.want) - } - }) - } -} - -type mockClient func(ctx context.Context, ignoreNotFound bool) (values.Values, error) - -func (m mockClient) GetValues(ctx context.Context, ignoreNotFound bool) (values.Values, error) { - return m(ctx, ignoreNotFound) -} - -type mockCache func(ctx context.Context) (values.Values, bool, error) - -func (m mockCache) Load(ctx context.Context) (values.Values, bool, error) { - return m(ctx) -} -func (m mockCache) Save(ctx context.Context, val values.Values) error { - return nil -} - -func Test_clientWithCache_GetValues(t *testing.T) { - type fields struct { - client Client - cache cache.Cache - } - type args struct { - ctx context.Context - ignoreNotFound bool - } - tests := []struct { - name string - fields fields - args args - want values.Values - wantErr bool - }{ - { - name: "ok: get values from cache", - fields: fields{ - client: mockClient(func(ctx context.Context, ignoreNotFound bool) (values.Values, error) { - return values.Values{ - "key-client": "value-client", - }, nil - }), - cache: mockCache(func(ctx context.Context) (values.Values, bool, error) { - return values.Values{ - "key-cache": "value-cache", - }, false, nil - }), - }, - args: args{ - ctx: context.Background(), - ignoreNotFound: false, - }, - want: values.Values{ - "key-cache": "value-cache", - }, - }, - { - name: "ok: get values from client(no cache)", - fields: fields{ - client: mockClient(func(ctx context.Context, ignoreNotFound bool) (values.Values, error) { - return values.Values{ - "key-client": "value-client", - }, nil - }), - cache: mockCache(func(ctx context.Context) (values.Values, bool, error) { - return values.Values{}, true, nil - }), - }, - args: args{ - ctx: context.Background(), - ignoreNotFound: false, - }, - want: values.Values{ - "key-client": "value-client", - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := &clientWithCache{ - client: tt.fields.client, - cache: tt.fields.cache, - } - got, err := c.GetValues(tt.args.ctx, tt.args.ignoreNotFound) - if (err != nil) != tt.wantErr { - t.Errorf("clientWithCache.GetValues() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("clientWithCache.GetValues() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/client/client.go b/internal/client/client.go index 652a837..682efdc 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -20,7 +20,6 @@ import ( "context" "errors" - "github.com/dwango/yashiro/internal/client/cache" "github.com/dwango/yashiro/internal/values" "github.com/dwango/yashiro/pkg/config" ) @@ -41,20 +40,12 @@ func New(cfg *config.Config) (Client, error) { var err error if cfg.Aws != nil { - client, err = newAwsClient(cfg.Aws) + client, err = newAwsClient(cfg) } if err != nil { return nil, err } - if cfg.Global.EnableCache { - cache, err := cache.New(cfg.Global.Cache) - if err != nil { - return nil, err - } - client = newClientWithCache(client, cache) - } - if client == nil { return nil, ErrNotfoundValueConfig } diff --git a/internal/client/client_test.go b/internal/client/client_test.go new file mode 100644 index 0000000..c42bc8d --- /dev/null +++ b/internal/client/client_test.go @@ -0,0 +1,90 @@ +/** + * Copyright 2024 DWANGO Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package client + +import ( + "context" + "reflect" + "testing" + + "github.com/dwango/yashiro/pkg/config" +) + +func TestNew(t *testing.T) { + type args struct { + cfg *config.Config + } + tests := []struct { + name string + args args + want Client + wantErr bool + }{ + { + name: "error: not found value config", + args: args{ + cfg: &config.Config{}, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := New(tt.args.cfg) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("New() = %v, want %v", got, tt.want) + } + }) + } +} + +var ( + mockLoadFunc = func(_ context.Context, key string, decrypt bool) (*string, bool, error) { + return stringPtr("value"), false, nil + } + mockLoadFuncNotFound = func(_ context.Context, key string, decrypt bool) (*string, bool, error) { + return nil, true, nil + } + mockLoadFuncExpired = func(_ context.Context, key string, decrypt bool) (*string, bool, error) { + return stringPtr("value"), true, nil + } + + mockSaveFunc = func(_ context.Context, key string, value *string, encrypt bool) error { + return nil + } +) + +type mockCache struct { + load func(ctx context.Context, key string, decrypt bool) (*string, bool, error) + save func(ctx context.Context, key string, value *string, encrypt bool) error +} + +func (m mockCache) Load(ctx context.Context, key string, decrypt bool) (*string, bool, error) { + return m.load(ctx, key, decrypt) +} + +func (m mockCache) Save(ctx context.Context, key string, value *string, encrypt bool) error { + return m.save(ctx, key, value, encrypt) +} + +func stringPtr(s string) *string { + return &s +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 97d442a..94d1195 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -20,6 +20,7 @@ import ( "context" "io" "os" + "time" "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" @@ -48,10 +49,13 @@ const ( ) type CacheConfig struct { - Type CacheType `json:"type,omitempty"` - File FileCacheConfig `json:"file,omitempty"` + Type CacheType `json:"type"` + ExpireDuration Duration `json:"expire_duration,omitempty"` + File FileCacheConfig `json:"file,omitempty"` } +const DefaultExpireDuration time.Duration = 30 * 24 * time.Hour // 30 days + type FileCacheConfig struct { CachePath string `json:"cache_path,omitempty"` } diff --git a/pkg/config/duration.go b/pkg/config/duration.go new file mode 100644 index 0000000..7aefb5f --- /dev/null +++ b/pkg/config/duration.go @@ -0,0 +1,54 @@ +/** + * Copyright 2024 DWANGO Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package config + +import ( + "encoding/json" + "fmt" + "time" +) + +type Duration time.Duration + +func (d Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(time.Duration(d).String()) +} + +func (d *Duration) UnmarshalJSON(b []byte) error { + if len(b) == 0 || string(b) == "null" { + return nil + } + + var v any + if err := json.Unmarshal(b, &v); err != nil { + return err + } + + switch value := v.(type) { + case float64: + *d = Duration(time.Duration(value)) + return nil + case string: + duration, err := time.ParseDuration(value) + if err != nil { + return err + } + *d = Duration(duration) + return nil + default: + return fmt.Errorf("invalid duration: %s", string(b)) + } +} diff --git a/pkg/config/duration_test.go b/pkg/config/duration_test.go new file mode 100644 index 0000000..1fd84d3 --- /dev/null +++ b/pkg/config/duration_test.go @@ -0,0 +1,89 @@ +/** + * Copyright 2024 DWANGO Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package config + +import ( + "encoding/json" + "reflect" + "testing" + "time" +) + +func TestMarshalJSON(t *testing.T) { + tests := []struct { + name string + d Duration + want string + wantErr bool + }{ + { + name: "ok", + d: Duration(24 * time.Hour), + want: `"24h0m0s"`, + }, + } + for _, tt := range tests { + got, err := json.Marshal(tt.d) + if (err != nil) != tt.wantErr { + t.Errorf("json.Marshal() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(string(got), tt.want) { + t.Errorf("json.Marshal() = %v, want %v", string(got), tt.want) + } + } +} + +func TestUnmarshalJSON(t *testing.T) { + type testStruct struct { + Duration Duration `json:"duration"` + } + tests := []struct { + name string + json string + want testStruct + wantErr bool + }{ + { + name: "ok: string", + json: `{"duration": "1s"}`, + want: testStruct{ + Duration: Duration(time.Second), + }, + }, + { + name: "ok: number", + json: `{"duration": 1000000000}`, // 1s + want: testStruct{ + Duration: Duration(time.Second), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got testStruct + err := json.Unmarshal([]byte(tt.json), &got) + if (err != nil) != tt.wantErr { + t.Errorf("jsonUnmarshal() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("jsonUnmarshal() = %v, want %v", got, tt.want) + } + }) + } + +}