diff --git a/pkg/cloudmeta/aws.go b/pkg/cloudmeta/aws.go
new file mode 100644
index 000000000..c50b0dd46
--- /dev/null
+++ b/pkg/cloudmeta/aws.go
@@ -0,0 +1,48 @@
+// Copyright (C) 2024 ScyllaDB
+
+package cloudmeta
+
+import (
+	"context"
+
+	"github.com/aws/aws-sdk-go/aws/ec2metadata"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/pkg/errors"
+)
+
+// awsMetadata fetches instance metadata through the EC2 metadata client.
+type awsMetadata struct {
+	ec2meta *ec2metadata.EC2Metadata
+}
+
+// newAWSMetadata creates an awsMetadata provider backed by a new AWS SDK
+// session. It returns an error when the session cannot be created.
+func newAWSMetadata() (*awsMetadata, error) {
+	sess, err := session.NewSession()
+	if err != nil {
+		return nil, errors.Wrap(err, "session.NewSession")
+	}
+
+	provider := &awsMetadata{ec2meta: ec2metadata.New(sess)}
+	return provider, nil
+}
+
+// Metadata returns the InstanceMetadata reported by AWS. It fails when the
+// metadata service is not available or the instance identity document
+// cannot be fetched.
+func (m *awsMetadata) Metadata(ctx context.Context) (InstanceMetadata, error) {
+	if available := m.ec2meta.AvailableWithContext(ctx); !available {
+		return InstanceMetadata{}, errors.New("metadata is not available")
+	}
+
+	doc, err := m.ec2meta.GetInstanceIdentityDocumentWithContext(ctx)
+	if err != nil {
+		return InstanceMetadata{}, errors.Wrap(err, "aws.metadataClient.GetInstanceIdentityDocument")
+	}
+
+	meta := InstanceMetadata{
+		CloudProvider: CloudProviderAWS,
+		InstanceType:  doc.InstanceType,
+	}
+	return meta, nil
+}
diff --git a/pkg/cloudmeta/metadata.go b/pkg/cloudmeta/metadata.go
new file mode 100644
index 000000000..a0507248c
--- /dev/null
+++ b/pkg/cloudmeta/metadata.go
@@ -0,0 +1,107 @@
+// Copyright (C) 2024 ScyllaDB
+
+package cloudmeta
+
+import (
+	"context"
+	"time"
+
+	"github.com/pkg/errors"
+	"go.uber.org/multierr"
+)
+
+// InstanceMetadata holds the instance details reported by a cloud provider.
+type InstanceMetadata struct {
+	InstanceType  string
+	CloudProvider CloudProvider
+}
+
+// CloudProvider names a supported cloud provider.
+type CloudProvider string
+
+// CloudProviderAWS is the CloudProvider value for AWS.
+var CloudProviderAWS CloudProvider = "aws"
+
+// CloudMetadataProvider is the interface that each metadata provider should implement.
+type CloudMetadataProvider interface {
+	Metadata(ctx context.Context) (InstanceMetadata, error)
+}
+
+// CloudMeta is a wrapper around various cloud metadata providers.
+type CloudMeta struct {
+	providers []CloudMetadataProvider
+
+	providerTimeout time.Duration
+}
+
+// NewCloudMeta creates a CloudMeta backed by the AWS provider with a 5s per provider timeout.
+func NewCloudMeta() (*CloudMeta, error) {
+	const defaultTimeout = 5 * time.Second
+
+	awsMeta, err := newAWSMetadata()
+	if err != nil {
+		return nil, err
+	}
+
+	return &CloudMeta{
+		providers:       []CloudMetadataProvider{awsMeta},
+		providerTimeout: defaultTimeout,
+	}, nil
+}
+
+// ErrNoProviders is returned by CloudMeta when it has no metadata providers configured.
+var ErrNoProviders = errors.New("no metadata providers found")
+
+// GetInstanceMetadata queries all configured providers concurrently and returns the
+// first successful result, or the combined errors once every provider has failed.
+// If ctx is done first, it returns the errors collected so far plus ctx.Err().
+func (cloud *CloudMeta) GetInstanceMetadata(ctx context.Context) (InstanceMetadata, error) {
+	if len(cloud.providers) == 0 {
+		return InstanceMetadata{}, errors.WithStack(ErrNoProviders)
+	}
+
+	type msg struct {
+		meta InstanceMetadata
+		err  error
+	}
+
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// One buffer slot per provider: the send never blocks, even after an early return.
+	results := make(chan msg, len(cloud.providers))
+	for _, provider := range cloud.providers {
+		go func(provider CloudMetadataProvider) {
+			meta, err := cloud.runWithTimeout(ctx, provider)
+			results <- msg{meta: meta, err: err}
+		}(provider)
+	}
+
+	// Return the first non error result or wait until all providers return err.
+	var mErr error
+	for range len(cloud.providers) {
+		select {
+		case <-ctx.Done():
+			return InstanceMetadata{}, multierr.Append(mErr, ctx.Err())
+		case res := <-results:
+			if res.err == nil {
+				return res.meta, nil
+			}
+			mErr = multierr.Append(mErr, res.err)
+		}
+	}
+	return InstanceMetadata{}, mErr
+}
+
+// runWithTimeout calls provider.Metadata with a context limited by providerTimeout.
+func (cloud *CloudMeta) runWithTimeout(ctx context.Context, provider CloudMetadataProvider) (InstanceMetadata, error) {
+	ctx, cancel := context.WithTimeout(ctx, cloud.providerTimeout)
+	defer cancel()
+
+	return provider.Metadata(ctx)
+}
+
+// WithProviderTimeout sets per provider timeout.
+func (cloud *CloudMeta) WithProviderTimeout(providerTimeout time.Duration) {
+	cloud.providerTimeout = providerTimeout
+}
diff --git a/pkg/cloudmeta/metadata_test.go b/pkg/cloudmeta/metadata_test.go
new file mode 100644
index 000000000..363819400
--- /dev/null
+++ b/pkg/cloudmeta/metadata_test.go
@@ -0,0 +1,175 @@
+// Copyright (C) 2024 ScyllaDB
+
+package cloudmeta
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+)
+
+func TestGetInstanceMetadata(t *testing.T) {
+	testCases := []struct {
+		name      string
+		providers []CloudMetadataProvider
+
+		expectedErr  bool
+		expectedMeta InstanceMetadata
+	}{
+		{
+			name:      "when there is no active providers",
+			providers: nil,
+
+			expectedErr:  true,
+			expectedMeta: InstanceMetadata{},
+		},
+		{
+			name: "when there is one active providers",
+			providers: []CloudMetadataProvider{
+				newTestProvider(t, "test_provider_1", "x-test-1", 1*time.Millisecond, nil),
+			},
+
+			expectedErr: false,
+			expectedMeta: InstanceMetadata{
+				CloudProvider: "test_provider_1",
+				InstanceType:  "x-test-1",
+			},
+		},
+		{
+			name: "when there is more than one active provider, fastest should be returned",
+			providers: []CloudMetadataProvider{
+				newTestProvider(t, "test_provider_1", "x-test-1", 1*time.Millisecond, nil),
+				// The wide latency gap keeps the winner stable on a loaded machine.
+				newTestProvider(t, "test_provider_2", "x-test-2", 100*time.Millisecond, nil),
+			},
+
+			expectedErr: false,
+			expectedMeta: InstanceMetadata{
+				CloudProvider: "test_provider_1",
+				InstanceType:  "x-test-1",
+			},
+		},
+		{
+			name: "when there is more than one active provider, but fastest returns err",
+			providers: []CloudMetadataProvider{
+				newTestProvider(t, "test_provider_1", "x-test-1", 1*time.Millisecond, errors.New("something went wront")),
+				newTestProvider(t, "test_provider_2", "x-test-2", 2*time.Millisecond, nil),
+			},
+
+			expectedErr: false,
+			expectedMeta: InstanceMetadata{
+				CloudProvider: "test_provider_2",
+				InstanceType:  "x-test-2",
+			},
+		},
+		{
+			name: "when there is more than one active provider, but all returns err",
+			providers: []CloudMetadataProvider{
+				newTestProvider(t, "test_provider_1", "x-test-1", 1*time.Millisecond, errors.New("err provider1")),
+				newTestProvider(t, "test_provider_2", "x-test-2", 1*time.Millisecond, errors.New("err provider2")),
+			},
+
+			expectedErr:  true,
+			expectedMeta: InstanceMetadata{},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// A zero providerTimeout would hand providers an already expired context.
+			cloudmeta := &CloudMeta{
+				providers:       tc.providers,
+				providerTimeout: time.Second,
+			}
+
+			meta, err := cloudmeta.GetInstanceMetadata(context.Background())
+
+			if tc.expectedErr && err == nil {
+				t.Fatalf("expected error, got: %v", err)
+			}
+
+			if !tc.expectedErr && err != nil {
+				t.Fatalf("unexpected error, got: %v", err)
+			}
+
+			if tc.expectedMeta.InstanceType != meta.InstanceType {
+				t.Fatalf("unexpected meta.InstanceType: %s != %s", tc.expectedMeta.InstanceType, meta.InstanceType)
+			}
+
+			if tc.expectedMeta.CloudProvider != meta.CloudProvider {
+				t.Fatalf("unexpected meta.CloudProvider: %s != %s", tc.expectedMeta.CloudProvider, meta.CloudProvider)
+			}
+		})
+	}
+}
+
+// TestGetInstanceMetadataCanceledContext checks that a canceled context makes
+// GetInstanceMetadata return an error instead of waiting for slow providers.
+func TestGetInstanceMetadataCanceledContext(t *testing.T) {
+	cloudmeta := &CloudMeta{
+		providers: []CloudMetadataProvider{
+			newTestProvider(t, "test_provider_1", "x-test-1", time.Minute, nil),
+		},
+		providerTimeout: time.Minute,
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	errCh := make(chan error, 1)
+	go func() {
+		_, err := cloudmeta.GetInstanceMetadata(ctx)
+		errCh <- err
+	}()
+
+	select {
+	case err := <-errCh:
+		if err == nil {
+			t.Fatal("expected error for canceled context, got nil")
+		}
+	case <-time.After(5 * time.Second):
+		t.Fatal("GetInstanceMetadata did not return after context cancellation")
+	}
+}
+
+// newTestProvider returns a testProvider with the given name, instance type, latency and error.
+func newTestProvider(t *testing.T, providerName, instanceType string, latency time.Duration, err error) *testProvider {
+	t.Helper()
+
+	return &testProvider{
+		name:         CloudProvider(providerName),
+		instanceType: instanceType,
+		latency:      latency,
+		err:          err,
+	}
+}
+
+// testProvider is a CloudMetadataProvider stub with a configurable latency and result.
+type testProvider struct {
+	name         CloudProvider
+	instanceType string
+	latency      time.Duration
+	err          error
+}
+
+// Metadata waits for the configured latency, or until ctx is done, and then
+// returns the configured error or metadata.
+func (tp testProvider) Metadata(ctx context.Context) (InstanceMetadata, error) {
+	timer := time.NewTimer(tp.latency)
+	defer timer.Stop()
+
+	select {
+	case <-ctx.Done():
+		return InstanceMetadata{}, ctx.Err()
+	case <-timer.C:
+	}
+
+	if tp.err != nil {
+		return InstanceMetadata{}, tp.err
+	}
+	return InstanceMetadata{
+		CloudProvider: tp.name,
+		InstanceType:  tp.instanceType,
+	}, nil
+}