Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 62 additions & 16 deletions pkg/splunk/client/awss3client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"regexp"
"strings"
Expand All @@ -42,6 +43,7 @@ var _ RemoteDataClient = &AWSS3Client{}
// SplunkAWSS3Client declares the S3 operations this package invokes through
// AWSS3Client.Client. Each method takes a context, the operation's input
// struct, and optional per-call s3.Options mutators.
// NOTE(review): presumably satisfied by the AWS SDK's *s3.Client and by test
// doubles — confirm against the code that constructs the client.
type SplunkAWSS3Client interface {
// ListObjectsV2 returns the listing output for the given input, or an error.
ListObjectsV2(ctx context.Context, input *s3.ListObjectsV2Input, options ...func(*s3.Options)) (*s3.ListObjectsV2Output, error)
// GetObject returns the object output for the given input, or an error.
GetObject(ctx context.Context, input *s3.GetObjectInput, options ...func(*s3.Options)) (*s3.GetObjectOutput, error)
// HeadObject returns the head output for the given input, or an error.
// DownloadApp calls it to read an object's current ETag before downloading.
HeadObject(ctx context.Context, input *s3.HeadObjectInput, options ...func(*s3.Options)) (*s3.HeadObjectOutput, error)
}

// SplunkAWSDownloadClient is used to download the apps from remote storage
Expand All @@ -62,18 +64,34 @@ type AWSS3Client struct {
Downloader SplunkAWSDownloadClient
}

// regionRegex matches the host name of an AWS S3 endpoint and captures the
// region in its third group. Groups one and two hold the leading boundary
// (start of string or a dot) and the literal "s3".
//
// The pattern is anchored at the end, so the host must end in
// ".amazonaws.com"; callers are expected to strip any scheme, port and path
// first. The (?i) flag applies to the whole pattern, including the
// [a-z0-9-] class, so upper-case hosts match as well.
var regionRegex = `(?i)(^|\.)(s3)[\.-]([a-z0-9-]+)\.amazonaws\.com$`
Copy link

Copilot AI Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The regex pattern uses the case-insensitive flag (?i), so the region group [a-z0-9-]+ already matches uppercase letters as well; writing the class as [a-zA-Z0-9-]+ does not change what is matched. Consider either removing the (?i) flag if only lowercase hosts should be accepted, or updating the character class to [a-zA-Z0-9-]+ purely so the case-insensitivity is explicit to readers.

Suggested change
var regionRegex = `(?i)(^|\.)(s3)[\.-]([a-z0-9-]+)\.amazonaws\.com$`
var regionRegex = `(?i)(^|\.)(s3)[\.-]([a-zA-Z0-9-]+)\.amazonaws\.com$`

Copilot uses AI. Check for mistakes.

// GetRegion extracts the AWS region from an S3 endpoint and stores it in
// *region.
//
// endpoint may be a full URL ("https://s3.us-west-2.amazonaws.com/bucket") or
// a bare host, optionally followed by a port or path
// ("s3.us-west-2.amazonaws.com:443"). Only the host name is matched against
// regionRegex.
//
// It returns nil on success. If no region can be found it returns an error
// that quotes both the endpoint and the host that was examined, and leaves
// *region untouched. ctx is accepted for signature compatibility and is not
// used.
func GetRegion(ctx context.Context, endpoint string, region *string) error {
	// Isolate the host name so scheme, port and path never reach the
	// end-anchored regex. A bare host has no "//" authority marker, so
	// url.Parse reports an empty hostname for it; retrying with "//"
	// prepended makes the parser treat the leading part as the authority.
	// If neither form yields a hostname, fall back to the raw string.
	host := endpoint
	for _, candidate := range []string{endpoint, "//" + endpoint} {
		if u, err := url.Parse(candidate); err == nil && u.Hostname() != "" {
			host = u.Hostname()
			break
		}
	}

	// Compiled per call because regionRegex is a mutable package variable.
	m := regexp.MustCompile(regionRegex).FindStringSubmatch(host)
	if len(m) < 4 {
		return fmt.Errorf("unable to extract region from the endpoint: %q (host: %q)", endpoint, host)
	}

	// Capture group 3 of regionRegex is the region.
	*region = m[3]
	return nil
}

// InitAWSClientWrapper is a wrapper around InitClientConfig
Expand Down Expand Up @@ -184,6 +202,7 @@ func NewAWSS3Client(ctx context.Context, bucketName string, accessKeyID string,
Prefix: prefix,
StartAfter: startAfter,
Client: s3SplunkClient,
Endpoint: endpoint,
Downloader: downloader,
}, nil
}
Expand Down Expand Up @@ -258,6 +277,14 @@ func (awsclient *AWSS3Client) DownloadApp(ctx context.Context, downloadRequest R
scopedLog := reqLogger.WithName("DownloadApp").WithValues("remoteFile", downloadRequest.RemoteFile, "localFile",
downloadRequest.LocalFile, "etag", downloadRequest.Etag)

// Validate inputs early, avoid calling downloader with bad args.
if strings.TrimSpace(downloadRequest.LocalFile) == "" {
return false, fmt.Errorf("local file path is empty")
}
if strings.TrimSpace(downloadRequest.RemoteFile) == "" {
return false, fmt.Errorf("remote file key is empty")
}

var numBytes int64
file, err := os.Create(downloadRequest.LocalFile)
if err != nil {
Expand All @@ -266,16 +293,35 @@ func (awsclient *AWSS3Client) DownloadApp(ctx context.Context, downloadRequest R
}
defer file.Close()

// Optional preflight: if the caller gave us an ETag, check the current one.
// We still download even if it differs, we just log for visibility.
if downloadRequest.Etag != "" {
if head, herr := awsclient.Client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(awsclient.BucketName),
Key: aws.String(downloadRequest.RemoteFile),
}); herr == nil {
current := aws.ToString(head.ETag)
if current != "" && current != downloadRequest.Etag {
scopedLog.Info("Provided ETag differs from current ETag in S3, will download latest",
"providedEtag", downloadRequest.Etag, "currentEtag", current)
}
} else {
scopedLog.Info("HeadObject failed, proceeding to download", "error", herr.Error())
}
}

getIn := &s3.GetObjectInput{
Bucket: aws.String(awsclient.BucketName),
Key: aws.String(downloadRequest.RemoteFile),
// Intentionally no IfMatch — avoids 412 PreconditionFailed when ETag is stale.
}

downloader := awsclient.Downloader
numBytes, err = downloader.Download(ctx, file,
&s3.GetObjectInput{
Bucket: aws.String(awsclient.BucketName),
Key: aws.String(downloadRequest.RemoteFile),
IfMatch: aws.String(downloadRequest.Etag),
})
numBytes, err = downloader.Download(ctx, file, getIn)
if err != nil {
scopedLog.Error(err, "Unable to download item", "RemoteFile", downloadRequest.RemoteFile)
os.Remove(downloadRequest.RemoteFile)
// Remove the partially written local file, not the remote key.
_ = os.Remove(downloadRequest.LocalFile)
return false, err
}

Expand Down
Loading
Loading