Skip to content

Commit

Permalink
fix test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-ext-simba-jl committed Dec 10, 2024
1 parent d430abb commit d44d533
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
18 changes: 9 additions & 9 deletions gcs_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,24 +414,24 @@ func newGcsClient(info *execResponseStageInfo) (gcsAPI, error) {

// TODO: SNOW-1789759 hardcoded region will be replaced in the future
endpoint := getGcsCustomEndpoint(info)

_, err := storage.NewClient(context.Background(), option.WithHTTPClient(httpClient), option.WithEndpoint(endpoint))
_, err := storage.NewClient(context.Background(), option.WithHTTPClient(httpClient))
if endpoint != "" {
_, err = storage.NewClient(context.Background(), option.WithHTTPClient(httpClient), option.WithEndpoint(endpoint))
}
if err != nil {
return nil, err
}

return httpClient, nil
}

func getGcsCustomEndpoint(info *execResponseStageInfo) string {
// TODO: SNOW-1789759 hardcoded region will be replaced in the future
var endpoint string
isRegionalURLEnabled := (strings.ToLower(info.Region) == gcsRegionMeCentral2) || info.UseRegionalURL
if info.EndPoint != "" {
endpoint = fmt.Sprintf("https://%s", info.EndPoint)
} else {
if info.Region != "" && isRegionalURLEnabled {
endpoint = fmt.Sprintf("https://storage.%s.rep.googleapis.com", strings.ToLower(info.Region))
}
return fmt.Sprintf("https://%s", info.EndPoint)
} else if info.Region != "" && isRegionalURLEnabled {
return fmt.Sprintf("https://storage.%s.rep.googleapis.com", strings.ToLower(info.Region))
}
return endpoint
return ""
}
16 changes: 7 additions & 9 deletions s3_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,13 @@ func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAcce
s3Logger := logging.LoggerFunc(s3LoggingFunc)

endpoint := getS3CustomEndpoint(info)
if endpoint == "" {
return nil, fmt.Errorf("error when retrieving endpoint")
}

return s3.New(s3.Options{
Region: info.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(
stageCredentials.AwsKeyID,
stageCredentials.AwsSecretKey,
stageCredentials.AwsToken)),
BaseEndpoint: &endpoint,
BaseEndpoint: endpoint,
UseAccelerate: useAccelerateEndpoint,
HTTPClient: &http.Client{
Transport: SnowflakeTransport,
Expand All @@ -70,19 +66,21 @@ func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAcce
}), nil
}

func getS3CustomEndpoint(info *execResponseStageInfo) string {
func getS3CustomEndpoint(info *execResponseStageInfo) *string {
if info.EndPoint != "" {
return fmt.Sprintf("https://%s", info.EndPoint)
endpoint := fmt.Sprintf("https://%s", info.EndPoint)
return &endpoint
}
isRegionalURLEnabled := info.UseRegionalURL || info.UseS3RegionalURL
if info.Region != "" && isRegionalURLEnabled {
domainSuffixForRegionalURL := "amazonaws.com"
if strings.HasPrefix(strings.ToLower(info.Region), "cn-") {
domainSuffixForRegionalURL = "amazonaws.com.cn"
}
return fmt.Sprintf("https://s3.%s.%s", info.Region, domainSuffixForRegionalURL)
endpoint := fmt.Sprintf("https://s3.%s.%s", info.Region, domainSuffixForRegionalURL)
return &endpoint
}
return ""
return nil
}

func s3LoggingFunc(classification logging.Classification, format string, v ...interface{}) {
Expand Down
5 changes: 4 additions & 1 deletion s3_storage_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,10 @@ func TestGetS3Endpoint(t *testing.T) {
for _, test := range testcases {
t.Run(test.desc, func(t *testing.T) {
endpoint := getS3CustomEndpoint(&test.in)
if endpoint != test.out {
if endpoint == nil && test.out != "" {
t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, nil, endpoint)
}
if endpoint != nil && *endpoint != test.out {
t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.out, endpoint)
}
})
Expand Down

0 comments on commit d44d533

Please sign in to comment.