Skip to content

Commit f9a4fe1

Browse files
author
Vivek Reddy
committed
CSPL-3974: improve region parsing and enforce latest version downloads
1 parent 0f64974 commit f9a4fe1

File tree

2 files changed

+179
-41
lines changed

2 files changed

+179
-41
lines changed

pkg/splunk/client/awss3client.go

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"fmt"
2323
"io"
2424
"net/http"
25+
"net/url"
2526
"os"
2627
"regexp"
2728
"strings"
@@ -42,6 +43,7 @@ var _ RemoteDataClient = &AWSS3Client{}
4243
type SplunkAWSS3Client interface {
4344
ListObjectsV2(ctx context.Context, input *s3.ListObjectsV2Input, options ...func(*s3.Options)) (*s3.ListObjectsV2Output, error)
4445
GetObject(ctx context.Context, input *s3.GetObjectInput, options ...func(*s3.Options)) (*s3.GetObjectOutput, error)
46+
HeadObject(ctx context.Context, input *s3.HeadObjectInput, options ...func(*s3.Options)) (*s3.HeadObjectOutput, error)
4547
}
4648

4749
// SplunkAWSDownloadClient is used to download the apps from remote storage
@@ -62,18 +64,34 @@ type AWSS3Client struct {
6264
Downloader SplunkAWSDownloadClient
6365
}
6466

65-
var regionRegex = ".*.s3[-,.]([a-z]+-[a-z]+-[0-9]+)\\..*amazonaws.com"
67+
var regionRegex = `(?i)(^|\.)(s3)[\.-]([a-z0-9-]+)\.amazonaws\.com$`
6668

6769
// GetRegion extracts the region from the endpoint field
6870
func GetRegion(ctx context.Context, endpoint string, region *string) error {
69-
var err error
70-
pattern := regexp.MustCompile(regionRegex)
71-
if len(pattern.FindStringSubmatch(endpoint)) > 1 {
72-
*region = pattern.FindStringSubmatch(endpoint)[1]
73-
} else {
74-
err = fmt.Errorf("unable to extract region from the endpoint")
75-
}
76-
return err
71+
var host string
72+
73+
// If endpoint looks like a URL, extract the hostname; otherwise use raw string.
74+
if u, err := url.Parse(endpoint); err == nil && u.Hostname() != "" {
75+
host = u.Hostname()
76+
} else {
77+
// tolerate raw host (with or without scheme)
78+
host = endpoint
79+
// strip a possible leading scheme manually if present
80+
if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") {
81+
if u2, err2 := url.Parse(host); err2 == nil && u2.Hostname() != "" {
82+
host = u2.Hostname()
83+
}
84+
}
85+
}
86+
87+
pattern := regexp.MustCompile(regionRegex)
88+
m := pattern.FindStringSubmatch(host)
89+
if len(m) >= 4 {
90+
// capture group 3 is the region
91+
*region = m[3]
92+
return nil
93+
}
94+
return fmt.Errorf("unable to extract region from the endpoint: %q (host: %q)", endpoint, host)
7795
}
7896

7997
// InitAWSClientWrapper is a wrapper around InitClientConfig
@@ -184,6 +202,7 @@ func NewAWSS3Client(ctx context.Context, bucketName string, accessKeyID string,
184202
Prefix: prefix,
185203
StartAfter: startAfter,
186204
Client: s3SplunkClient,
205+
Endpoint: endpoint,
187206
Downloader: downloader,
188207
}, nil
189208
}
@@ -258,6 +277,14 @@ func (awsclient *AWSS3Client) DownloadApp(ctx context.Context, downloadRequest R
258277
scopedLog := reqLogger.WithName("DownloadApp").WithValues("remoteFile", downloadRequest.RemoteFile, "localFile",
259278
downloadRequest.LocalFile, "etag", downloadRequest.Etag)
260279

280+
// Validate inputs early, avoid calling downloader with bad args.
281+
if strings.TrimSpace(downloadRequest.LocalFile) == "" {
282+
return false, fmt.Errorf("local file path is empty")
283+
}
284+
if strings.TrimSpace(downloadRequest.RemoteFile) == "" {
285+
return false, fmt.Errorf("remote file key is empty")
286+
}
287+
261288
var numBytes int64
262289
file, err := os.Create(downloadRequest.LocalFile)
263290
if err != nil {
@@ -266,16 +293,35 @@ func (awsclient *AWSS3Client) DownloadApp(ctx context.Context, downloadRequest R
266293
}
267294
defer file.Close()
268295

296+
// Optional preflight: if the caller gave us an ETag, check the current one.
297+
// We still download even if it differs, we just log for visibility.
298+
if downloadRequest.Etag != "" {
299+
if head, herr := awsclient.Client.HeadObject(ctx, &s3.HeadObjectInput{
300+
Bucket: aws.String(awsclient.BucketName),
301+
Key: aws.String(downloadRequest.RemoteFile),
302+
}); herr == nil {
303+
current := aws.ToString(head.ETag)
304+
if current != "" && current != downloadRequest.Etag {
305+
scopedLog.Info("Provided ETag differs from current ETag in S3, will download latest",
306+
"providedEtag", downloadRequest.Etag, "currentEtag", current)
307+
}
308+
} else {
309+
scopedLog.Info("HeadObject failed, proceeding to download", "error", herr.Error())
310+
}
311+
}
312+
313+
getIn := &s3.GetObjectInput{
314+
Bucket: aws.String(awsclient.BucketName),
315+
Key: aws.String(downloadRequest.RemoteFile),
316+
// Intentionally no IfMatch — avoids 412 PreconditionFailed when ETag is stale.
317+
}
318+
269319
downloader := awsclient.Downloader
270-
numBytes, err = downloader.Download(ctx, file,
271-
&s3.GetObjectInput{
272-
Bucket: aws.String(awsclient.BucketName),
273-
Key: aws.String(downloadRequest.RemoteFile),
274-
IfMatch: aws.String(downloadRequest.Etag),
275-
})
320+
numBytes, err = downloader.Download(ctx, file, getIn)
276321
if err != nil {
277322
scopedLog.Error(err, "Unable to download item", "RemoteFile", downloadRequest.RemoteFile)
278-
os.Remove(downloadRequest.RemoteFile)
323+
// Remove the partially written local file, not the remote key.
324+
_ = os.Remove(downloadRequest.LocalFile)
279325
return false, err
280326
}
281327

0 commit comments

Comments
 (0)