Skip to content

Commit

Permalink
SNOW-1789753: Support GCS region specific endpoint (#1280)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-ext-simba-jy authored Dec 20, 2024
1 parent 7d34091 commit 8257f91
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 18 deletions.
27 changes: 21 additions & 6 deletions gcs_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
gcsMetadataMatdescKey = gcsMetadataPrefix + "matdesc"
gcsMetadataEncryptionDataProp = gcsMetadataPrefix + "encryptiondata"
gcsFileHeaderDigest = "gcs-file-header-digest"
gcsRegionMeCentral2 = "me-central2"
)

type snowflakeGcsClient struct {
Expand Down Expand Up @@ -52,7 +53,7 @@ func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename strin
if meta.presignedURL != nil {
meta.resStatus = notFoundFile
} else {
URL, err := util.generateFileURL(meta.stageInfo.Location, strings.TrimLeft(filename, "/"))
URL, err := util.generateFileURL(meta.stageInfo, strings.TrimLeft(filename, "/"))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -147,7 +148,7 @@ func (util *snowflakeGcsClient) uploadFile(
var err error

if uploadURL == nil {
uploadURL, err = util.generateFileURL(meta.stageInfo.Location, strings.TrimLeft(meta.dstFileName, "/"))
uploadURL, err = util.generateFileURL(meta.stageInfo, strings.TrimLeft(meta.dstFileName, "/"))
if err != nil {
return err
}
Expand Down Expand Up @@ -279,7 +280,7 @@ func (util *snowflakeGcsClient) nativeDownloadFile(
gcsHeaders := make(map[string]string)

if downloadURL == nil || downloadURL.String() == "" {
downloadURL, err = util.generateFileURL(meta.stageInfo.Location, strings.TrimLeft(meta.srcFileName, "/"))
downloadURL, err = util.generateFileURL(meta.stageInfo, strings.TrimLeft(meta.srcFileName, "/"))
if err != nil {
return err
}
Expand Down Expand Up @@ -388,10 +389,11 @@ func (util *snowflakeGcsClient) extractBucketNameAndPath(location string) *gcsLo
return &gcsLocation{containerName, path}
}

func (util *snowflakeGcsClient) generateFileURL(stageLocation string, filename string) (*url.URL, error) {
gcsLoc := util.extractBucketNameAndPath(stageLocation)
func (util *snowflakeGcsClient) generateFileURL(stageInfo *execResponseStageInfo, filename string) (*url.URL, error) {
gcsLoc := util.extractBucketNameAndPath(stageInfo.Location)
fullFilePath := gcsLoc.path + filename
URL, err := url.Parse("https://storage.googleapis.com/" + gcsLoc.bucketName + "/" + url.QueryEscape(fullFilePath))
endPoint := getGcsCustomEndpoint(stageInfo)
URL, err := url.Parse(endPoint + "/" + gcsLoc.bucketName + "/" + url.QueryEscape(fullFilePath))
if err != nil {
return nil, err
}
Expand All @@ -407,3 +409,16 @@ func newGcsClient() gcsAPI {
Transport: SnowflakeTransport,
}
}

// getGcsCustomEndpoint returns the base URL (scheme + host, no trailing slash)
// used to build GCS object URLs for the given stage.
//
// Resolution order:
//  1. info.EndPoint, when non-empty, is used verbatim behind "https://".
//  2. Otherwise, when info.Region is non-empty and regional URLs are enabled
//     (info.UseRegionalURL is set, or the region is me-central2 compared
//     case-insensitively), the lowercased region is embedded in
//     "https://storage.<region>.rep.googleapis.com".
//  3. Otherwise the global "https://storage.googleapis.com" is returned.
func getGcsCustomEndpoint(info *execResponseStageInfo) string {
	// An explicitly configured endpoint always takes precedence.
	if info.EndPoint != "" {
		return fmt.Sprintf("https://%s", info.EndPoint)
	}

	region := strings.ToLower(info.Region)

	// TODO: SNOW-1789759 hardcoded region will be replaced in the future
	regionalURLEnabled := info.UseRegionalURL || region == gcsRegionMeCentral2
	if region != "" && regionalURLEnabled {
		return fmt.Sprintf("https://storage.%s.rep.googleapis.com", region)
	}

	return "https://storage.googleapis.com"
}
94 changes: 93 additions & 1 deletion gcs_storage_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ func TestGenerateFileURL(t *testing.T) {
}
for _, test := range testcases {
t.Run(test.location, func(t *testing.T) {
gcsURL, err := gcsUtil.generateFileURL(test.location, test.fname)
stageInfo := &execResponseStageInfo{}
stageInfo.Location = test.location
gcsURL, err := gcsUtil.generateFileURL(stageInfo, test.fname)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -1126,3 +1128,93 @@ func Test_snowflakeGcsClient_nativeDownloadFile(t *testing.T) {
t.Error("should have raised an error")
}
}

// TestGetGcsCustomEndpoint checks the endpoint chosen by getGcsCustomEndpoint
// for combinations of EndPoint, Region and UseRegionalURL. Fields left out of
// a case's stage info keep their zero value (false / "").
func TestGetGcsCustomEndpoint(t *testing.T) {
	const (
		globalEndpoint    = "https://storage.googleapis.com"
		meCentral2URL     = "https://storage.me-central2.rep.googleapis.com"
		specialEndPoint   = "storage.specialEndPoint.rep.googleapis.com"
		specialEndPointURL = "https://storage.specialEndPoint.rep.googleapis.com"
	)

	cases := []struct {
		name      string
		stageInfo execResponseStageInfo
		want      string
	}{
		{
			name:      "when the endPoint is not specified and UseRegionalURL is false",
			stageInfo: execResponseStageInfo{Region: "WEST-1"},
			want:      globalEndpoint,
		},
		{
			name:      "when the useRegionalURL is only enabled",
			stageInfo: execResponseStageInfo{UseRegionalURL: true, Region: "mockLocation"},
			want:      "https://storage.mocklocation.rep.googleapis.com",
		},
		{
			name:      "when the region is me-central2",
			stageInfo: execResponseStageInfo{Region: "me-central2"},
			want:      meCentral2URL,
		},
		{
			name:      "when the region is me-central2 (mixed case)",
			stageInfo: execResponseStageInfo{Region: "ME-cEntRal2"},
			want:      meCentral2URL,
		},
		{
			name:      "when the region is me-central2 (uppercase)",
			stageInfo: execResponseStageInfo{Region: "ME-CENTRAL2"},
			want:      meCentral2URL,
		},
		{
			name:      "when the endPoint is specified",
			stageInfo: execResponseStageInfo{EndPoint: specialEndPoint, Region: "ME-cEntRal1"},
			want:      specialEndPointURL,
		},
		{
			name: "when both the endPoint and the useRegionalUrl are specified",
			stageInfo: execResponseStageInfo{
				UseRegionalURL: true,
				EndPoint:       specialEndPoint,
				Region:         "ME-cEntRal1",
			},
			want: specialEndPointURL,
		},
		{
			name: "when both the endPoint is specified and the region is me-central2",
			stageInfo: execResponseStageInfo{
				UseRegionalURL: true,
				EndPoint:       specialEndPoint,
				Region:         "ME-CENTRAL2",
			},
			want: specialEndPointURL,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			got := getGcsCustomEndpoint(&tc.stageInfo)
			if got == tc.want {
				return
			}
			t.Errorf("failed. in: %v, expected: %v, got: %v", tc.stageInfo, tc.want, got)
		})
	}
}
2 changes: 2 additions & 0 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ type execResponseStageInfo struct {
Creds execResponseCredentials `json:"creds,omitempty"`
PresignedURL string `json:"presignedUrl,omitempty"`
EndPoint string `json:"endPoint,omitempty"`
UseS3RegionalURL bool `json:"useS3RegionalUrl,omitempty"`
UseRegionalURL bool `json:"useRegionalUrl,omitempty"`
}

// make all data field optional
Expand Down
35 changes: 24 additions & 11 deletions s3_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@ import (
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/smithy-go"
"github.com/aws/smithy-go/logging"
"io"
"net/http"
"os"
"strings"
)

const (
Expand Down Expand Up @@ -47,20 +48,15 @@ var S3LoggingMode aws.ClientLogMode
func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
stageCredentials := info.Creds
s3Logger := logging.LoggerFunc(s3LoggingFunc)

var endpoint *string
if info.EndPoint != "" {
tmp := "https://" + info.EndPoint
endpoint = &tmp
}
endPoint := getS3CustomEndpoint(info)

return s3.New(s3.Options{
Region: info.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(
stageCredentials.AwsKeyID,
stageCredentials.AwsSecretKey,
stageCredentials.AwsToken)),
BaseEndpoint: endpoint,
BaseEndpoint: endPoint,
UseAccelerate: useAccelerateEndpoint,
HTTPClient: &http.Client{
Transport: SnowflakeTransport,
Expand All @@ -70,6 +66,23 @@ func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAcce
}), nil
}

// getS3CustomEndpoint returns the base endpoint to hand to the S3 client for
// the given stage, or nil when the SDK's default endpoint resolution should
// be used.
//
// Resolution order:
//  1. info.EndPoint, when non-empty, is used verbatim behind "https://".
//  2. Otherwise, when info.Region is non-empty and either UseRegionalURL or
//     UseS3RegionalURL is set, a regional URL "https://s3.<region>.<suffix>"
//     is built. The suffix is "amazonaws.com.cn" for regions whose name
//     starts with "cn-" (case-insensitive) and "amazonaws.com" otherwise.
//     The region is inserted exactly as received.
//  3. Otherwise nil is returned.
func getS3CustomEndpoint(info *execResponseStageInfo) *string {
	// An explicitly configured endpoint always takes precedence.
	if info.EndPoint != "" {
		custom := fmt.Sprintf("https://%s", info.EndPoint)
		return &custom
	}

	// No region, or regional URLs not requested: no custom endpoint.
	if info.Region == "" || !(info.UseRegionalURL || info.UseS3RegionalURL) {
		return nil
	}

	suffix := "amazonaws.com"
	if strings.HasPrefix(strings.ToLower(info.Region), "cn-") {
		suffix = "amazonaws.com.cn"
	}
	regional := fmt.Sprintf("https://s3.%s.%s", info.Region, suffix)
	return &regional
}

func s3LoggingFunc(classification logging.Classification, format string, v ...interface{}) {
switch classification {
case logging.Debug:
Expand Down
99 changes: 99 additions & 0 deletions s3_storage_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -793,3 +793,102 @@ func TestConvertContentLength(t *testing.T) {
})
}
}

// TestGetS3Endpoint checks the endpoint chosen by getS3CustomEndpoint for
// combinations of EndPoint, Region, UseRegionalURL and UseS3RegionalURL.
//
// The table covers every case in which a non-nil endpoint is expected; the
// final subtest covers the case where no custom endpoint applies and nil is
// returned.
func TestGetS3Endpoint(t *testing.T) {
	testcases := []struct {
		desc string
		in   execResponseStageInfo
		out  string
	}{
		{
			desc: "when UseRegionalURL is valid and the region does not start with cn-",
			in: execResponseStageInfo{
				UseS3RegionalURL: false,
				UseRegionalURL:   true,
				EndPoint:         "",
				Region:           "WEST-1",
			},
			out: "https://s3.WEST-1.amazonaws.com",
		},
		{
			desc: "when UseS3RegionalURL is valid and the region does not start with cn-",
			in: execResponseStageInfo{
				UseS3RegionalURL: true,
				UseRegionalURL:   false,
				EndPoint:         "",
				Region:           "WEST-1",
			},
			out: "https://s3.WEST-1.amazonaws.com",
		},
		{
			desc: "when endPoint is enabled and the region does not start with cn-",
			in: execResponseStageInfo{
				UseS3RegionalURL: false,
				UseRegionalURL:   false,
				EndPoint:         "s3.endpoint",
				Region:           "mockLocation",
			},
			out: "https://s3.endpoint",
		},
		{
			desc: "when endPoint is enabled and the region starts with cn-",
			in: execResponseStageInfo{
				UseS3RegionalURL: false,
				UseRegionalURL:   false,
				EndPoint:         "s3.endpoint",
				Region:           "cn-mockLocation",
			},
			out: "https://s3.endpoint",
		},
		{
			desc: "when useS3RegionalURL is valid and domain starts with cn",
			in: execResponseStageInfo{
				UseS3RegionalURL: true,
				UseRegionalURL:   false,
				EndPoint:         "",
				Region:           "cn-mockLocation",
			},
			out: "https://s3.cn-mockLocation.amazonaws.com.cn",
		},
		{
			// Previously this case set UseS3RegionalURL instead of
			// UseRegionalURL, so the UseRegionalURL + cn path was untested.
			desc: "when useRegionalURL is valid and domain starts with cn",
			in: execResponseStageInfo{
				UseS3RegionalURL: false,
				UseRegionalURL:   true,
				EndPoint:         "",
				Region:           "cn-mockLocation",
			},
			out: "https://s3.cn-mockLocation.amazonaws.com.cn",
		},
		{
			// Replaces an exact duplicate of the previous case.
			desc: "when both useRegionalURL and useS3RegionalURL are valid and domain starts with cn",
			in: execResponseStageInfo{
				UseS3RegionalURL: true,
				UseRegionalURL:   true,
				EndPoint:         "",
				Region:           "cn-mockLocation",
			},
			out: "https://s3.cn-mockLocation.amazonaws.com.cn",
		},
		{
			desc: "when endPoint is specified, both UseRegionalURL and useS3PRegionalUrl are valid, and the region starts with cn",
			in: execResponseStageInfo{
				UseS3RegionalURL: true,
				UseRegionalURL:   true,
				EndPoint:         "s3.endpoint",
				Region:           "cn-mockLocation",
			},
			out: "https://s3.endpoint",
		},
	}

	for _, test := range testcases {
		t.Run(test.desc, func(t *testing.T) {
			endpoint := getS3CustomEndpoint(&test.in)
			// Guard the dereference so a regression that returns nil is
			// reported as a failure instead of panicking the test binary.
			if endpoint == nil {
				t.Fatalf("failed. in: %v, expected: %v, got: nil", test.in, test.out)
			}
			if *endpoint != test.out {
				t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.out, *endpoint)
			}
		})
	}

	// With no explicit endpoint and no regional-URL flag set,
	// getS3CustomEndpoint returns nil.
	t.Run("when neither endPoint nor a regional URL flag is specified", func(t *testing.T) {
		in := execResponseStageInfo{Region: "WEST-1"}
		if endpoint := getS3CustomEndpoint(&in); endpoint != nil {
			t.Errorf("failed. in: %v, expected: nil, got: %v", in, *endpoint)
		}
	})
}

0 comments on commit 8257f91

Please sign in to comment.