Skip to content

Commit

Permalink
feat: add strict typing for encryption conf
Browse files Browse the repository at this point in the history
add strict typing to enforce usage of attributes for the encryption
config that are recongnized by tofu to avoid defering error handling to
them
  • Loading branch information
norman-zon committed Nov 27, 2024
1 parent e168646 commit 41e4fc1
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 2 deletions.
1 change: 0 additions & 1 deletion codegen/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ func RemoteStateConfigToTerraformCode(backend string, config map[string]interfac
if !found {
return nil, fmt.Errorf(encryptionKeyProviderKey + " is mandatory but not found in the encryption map")
}

keyProviderTraversal := hcl.Traversal{
hcl.TraverseRoot{Name: encryptionKeyProviderKey},
hcl.TraverseAttr{Name: keyProvider},
Expand Down
89 changes: 89 additions & 0 deletions remote/remote_encryption.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package remote

import (
"fmt"

"github.com/mitchellh/mapstructure"
)

type RemoteEncryptionConfig interface {
UnmarshalConfig(encryptionConfig map[string]interface{}) error
ToMap() (map[string]interface{}, error)
}

type RemoteEncryptionKeyProvider interface {
RemoteEncryptionKeyProviderPBKDF2 | RemoteEncryptionKeyProviderGCPKMS | RemoteEncryptionKeyProviderAWSKMS
}

type RemoteEncryptionKeyProviderBase struct {
KeyProvider string `mapstructure:"key_provider"`
}

type GenericRemoteEncryptionKeyProvider[T RemoteEncryptionKeyProvider] struct {
T T
}

func (b *GenericRemoteEncryptionKeyProvider[T]) UnmarshalConfig(encryptionConfig map[string]interface{}) error {
// Decode the key provider type using the default decoder config
if err := mapstructure.Decode(encryptionConfig, &b); err != nil {
return fmt.Errorf("failed to decode key provider: %w", err)
}

// Decode the key provider properties using, setting ErrorUnused to true to catch any unused properties
decoderConfig := &mapstructure.DecoderConfig{
Result: &b.T,
ErrorUnused: true,
}
decoder, err := mapstructure.NewDecoder(decoderConfig)
if err != nil {
return fmt.Errorf("failed to create decoder: %w", err)
}
if err := decoder.Decode(encryptionConfig); err != nil {
return fmt.Errorf("failed to decode key provider properties: %w", err)
}

return nil
}

func (b *GenericRemoteEncryptionKeyProvider[T]) ToMap() (map[string]interface{}, error) {
var result map[string]interface{}
err := mapstructure.Decode(b.T, &result)
if err != nil {
return nil, fmt.Errorf("failed to decode struct to map: %w", err)
}
return result, nil
}

func NewRemoteEncryptionKeyProvider(providerType string) (RemoteEncryptionConfig, error) {
switch providerType {
case "pbkdf2":
return &GenericRemoteEncryptionKeyProvider[RemoteEncryptionKeyProviderPBKDF2]{}, nil
case "gcp_kms":
return &GenericRemoteEncryptionKeyProvider[RemoteEncryptionKeyProviderGCPKMS]{}, nil
case "aws_kms":
return &GenericRemoteEncryptionKeyProvider[RemoteEncryptionKeyProviderAWSKMS]{}, nil
default:
return nil, fmt.Errorf("unknown provider type: %s", providerType)
}
}

type RemoteEncryptionKeyProviderPBKDF2 struct {
RemoteEncryptionKeyProviderBase `mapstructure:",squash"`
Passphrase string `mapstructure:"passphrase"`
KeyLength int `mapstructure:"key_length"`
Iterations int `mapstructure:"iterations"`
SaltLength int `mapstructure:"salt_length"`
HashFunction string `mapstructure:"hash_function"`
}

type RemoteEncryptionKeyProviderAWSKMS struct {
RemoteEncryptionKeyProviderBase `mapstructure:",squash"`
KmsKeyId int `mapstructure:"kms_key_id"`
KeySpec string `mapstructure:"key_spec"`
}

type RemoteEncryptionKeyProviderGCPKMS struct {
RemoteEncryptionKeyProviderBase `mapstructure:",squash"`
KmsEncryptionKey string `mapstructure:"kms_encryption_key"`
KeyLength int `mapstructure:"key_length"`
}
206 changes: 206 additions & 0 deletions remote/remote_encryption_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package remote_test

import (
"testing"

"github.com/gruntwork-io/terragrunt/remote"
"github.com/stretchr/testify/assert"
)

func TestUnmarshalConfig(t *testing.T) {
tests := []struct {
name string
providerType string
encryptionConfig map[string]interface{}
expectedError bool
}{
{
name: "PBKDF2 valid config",
providerType: "pbkdf2",
encryptionConfig: map[string]interface{}{
"key_provider": "pbkdf2",
"passphrase": "passphrase",
"key_length": 32,
"iterations": 10000,
"salt_length": 16,
"hash_function": "sha256",
},
expectedError: false,
},
{
name: "PBKDF2 invalid property",
providerType: "pbkdf2",
encryptionConfig: map[string]interface{}{
"key_provider": "pbkdf2",
"password": "password123", // Invalid property
},
expectedError: true,
},
{
name: "PBKDF2 invalid config",
providerType: "pbkdf2",
encryptionConfig: map[string]interface{}{
"key_provider": "pbkdf2",
"passphrase": 123, // Invalid type
},
expectedError: true,
},
{
name: "AWSKMS valid config",
providerType: "aws_kms",
encryptionConfig: map[string]interface{}{
"key_provider": "aws_kms",
"kms_key_id": 123456789,
"key_spec": "AES_256",
},
expectedError: false,
},
{
name: "AWSKMS invalid property",
providerType: "aws_kms",
encryptionConfig: map[string]interface{}{
"key_provider": "aws_kms",
"password": "password123", // Invalid property
},
expectedError: true,
},
{
name: "AWSKMS invalid config",
providerType: "aws_kms",
encryptionConfig: map[string]interface{}{
"key_provider": "aws_kms",
"kms_key_id": "invalid_id", // Invalid type
"key_spec": "AES_256",
},
expectedError: true,
},
{
name: "GCPKMS valid config",
providerType: "gcp_kms",
encryptionConfig: map[string]interface{}{
"key_provider": "gcp_kms",
"kms_encryption_key": "projects/123456789/locations/global/keyRings/my-key-ring/cryptoKeys/my-key",
"key_length": 32,
},
expectedError: false,
},
{
name: "GCPKMS invalid property",
providerType: "gcp_kms",
encryptionConfig: map[string]interface{}{
"key_provider": "gcp_kms",
"password": "password123", // Invalid property
},
expectedError: true,
},
{
name: "GCPKMS invalid config",
providerType: "gcp_kms",
encryptionConfig: map[string]interface{}{
"key_provider": "gcp_kms",
"kms_encryption_key": 123456789, // Invalid type
"key_length": 32,
},
expectedError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, err := remote.NewRemoteEncryptionKeyProvider(tt.providerType)
if err != nil {
t.Fatalf("failed to create provider: %v", err)
}

err = provider.UnmarshalConfig(tt.encryptionConfig)
if tt.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestToMap(t *testing.T) {
tests := []struct {
name string
providerType string
encryptionConfig map[string]interface{}
expectedMap map[string]interface{}
expectedError bool
}{
{
name: "PBKDF2 valid config",
providerType: "pbkdf2",
encryptionConfig: map[string]interface{}{
"key_provider": "pbkdf2",
"passphrase": "passphrase",
"key_length": 32,
"iterations": 10000,
"salt_length": 16,
"hash_function": "sha256",
},
expectedMap: map[string]interface{}{
"key_provider": "pbkdf2",
"passphrase": "passphrase",
"key_length": 32,
"iterations": 10000,
"salt_length": 16,
"hash_function": "sha256",
},
expectedError: false,
},
{
name: "AWSKMS valid config",
providerType: "aws_kms",
encryptionConfig: map[string]interface{}{
"key_provider": "aws_kms",
"kms_key_id": 123456789,
"key_spec": "AES_256",
},
expectedMap: map[string]interface{}{
"key_provider": "aws_kms",
"kms_key_id": 123456789,
"key_spec": "AES_256",
},
expectedError: false,
},
{
name: "GCPKMS valid config",
providerType: "gcp_kms",
encryptionConfig: map[string]interface{}{
"key_provider": "gcp_kms",
"kms_encryption_key": "projects/123456789/locations/global/keyRings/my-key-ring/cryptoKeys/my-key",
"key_length": 32,
},
expectedMap: map[string]interface{}{
"key_provider": "gcp_kms",
"kms_encryption_key": "projects/123456789/locations/global/keyRings/my-key-ring/cryptoKeys/my-key",
"key_length": 32,
},
expectedError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, err := remote.NewRemoteEncryptionKeyProvider(tt.providerType)
if err != nil {
t.Fatalf("failed to create provider: %v", err)
}

err = provider.UnmarshalConfig(tt.encryptionConfig)
if err != nil {
t.Fatalf("failed to unmarshal config: %v", err)
}

result, err := provider.ToMap()
if tt.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedMap, result)
}
})
}
}
21 changes: 20 additions & 1 deletion remote/remote_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,26 @@ func (state *RemoteState) GenerateTerraformCode(terragruntOptions *options.Terra
// Make sure to strip out terragrunt specific configurations from the config.
config := state.Config

encryption := state.Encryption
// Initialize the encryption config based on the key provider
keyProvider, ok := state.Encryption["key_provider"].(string)
if !ok {
return errors.New("key_provider not found in encryption config")
}

encryptionProvider, err := NewRemoteEncryptionKeyProvider(keyProvider)
if err != nil {
return fmt.Errorf("error creating provider: %v", err)
}

err = encryptionProvider.UnmarshalConfig(state.Encryption)
if err != nil {
return err
}

encryption, err := encryptionProvider.ToMap()
if err != nil {
return fmt.Errorf("error decoding struct to map: %v", err)
}

initializer, hasInitializer := remoteStateInitializers[state.Backend]
if hasInitializer {
Expand Down

0 comments on commit 41e4fc1

Please sign in to comment.