From c77da5dd47192d209bbc45284756bd0ecc9bbccc Mon Sep 17 00:00:00 2001 From: Martin Grogan Date: Wed, 18 Dec 2024 17:10:36 -0500 Subject: [PATCH] commonsteps: StepDownload timeout override Should now be overridable via the environment variable PACKER_GETTER_READ_TIMEOUT --- multistep/commonsteps/step_download.go | 81 +++++++++++++-------- multistep/commonsteps/step_download_test.go | 44 +++++++++++ 2 files changed, 93 insertions(+), 32 deletions(-) diff --git a/multistep/commonsteps/step_download.go b/multistep/commonsteps/step_download.go index 8daa8ab40..f482f0196 100644 --- a/multistep/commonsteps/step_download.go +++ b/multistep/commonsteps/step_download.go @@ -64,40 +64,57 @@ type StepDownload struct { // The timeout must be long enough to accommodate large/slow downloads. const defaultGetterReadTimeout time.Duration = 30 * time.Minute -var defaultGetterClient = getter.Client{ - // Disable writing and reading through symlinks. - DisableSymlinks: true, - // The order of the Getters in the list may affect the result - // depending if the Request.Src is detected as valid by multiple getters - Getters: []getter.Getter{ - &getter.GitGetter{ - Timeout: defaultGetterReadTimeout, - Detectors: []getter.Detector{ - new(getter.GitHubDetector), - new(getter.GitDetector), - new(getter.BitBucketDetector), - new(getter.GitLabDetector), +var getterReadTimeout = defaultGetterReadTimeout +var defaultGetterClient = getter.Client{} + +func init() { + prepareGetterClient() +} + +func prepareGetterClient() { + getterReadTimeout := defaultGetterReadTimeout + if env, exists := os.LookupEnv("PACKER_GETTER_READ_TIMEOUT"); exists { + parsedDuration, err := time.ParseDuration(env) + if err != nil { + panic(err) + } + getterReadTimeout = parsedDuration + } + defaultGetterClient = getter.Client{ + // Disable writing and reading through symlinks. + DisableSymlinks: true, + // The order of the Getters in the list may affect the result + // depending if the Request.Src is detected as valid by multiple getters + Getters: []getter.Getter{ + &getter.GitGetter{ + Timeout: getterReadTimeout, + Detectors: []getter.Detector{ + new(getter.GitHubDetector), + new(getter.GitDetector), + new(getter.BitBucketDetector), + new(getter.GitLabDetector), + }, + }, + &getter.HgGetter{ + Timeout: getterReadTimeout, + }, + new(getter.SmbClientGetter), + new(getter.SmbMountGetter), + &getter.HttpGetter{ + Netrc: true, + XTerraformGetDisabled: true, + HeadFirstTimeout: getterReadTimeout, + ReadTimeout: getterReadTimeout, + }, + new(getter.FileGetter), + &gcs.Getter{ + Timeout: getterReadTimeout, + }, + &s3.Getter{ + Timeout: getterReadTimeout, }, }, - &getter.HgGetter{ - Timeout: defaultGetterReadTimeout, - }, - new(getter.SmbClientGetter), - new(getter.SmbMountGetter), - &getter.HttpGetter{ - Netrc: true, - XTerraformGetDisabled: true, - HeadFirstTimeout: defaultGetterReadTimeout, - ReadTimeout: defaultGetterReadTimeout, - }, - new(getter.FileGetter), - &gcs.Getter{ - Timeout: defaultGetterReadTimeout, - }, - &s3.Getter{ - Timeout: defaultGetterReadTimeout, - }, - }, + } } func (s *StepDownload) Run(ctx context.Context, state multistep.StateBag) multistep.StepAction { diff --git a/multistep/commonsteps/step_download_test.go b/multistep/commonsteps/step_download_test.go index 2534e4f63..6eab3a91f 100644 --- a/multistep/commonsteps/step_download_test.go +++ b/multistep/commonsteps/step_download_test.go @@ -15,7 +15,9 @@ import ( "path/filepath" "reflect" "runtime" + "strings" "testing" + "time" "github.com/google/go-cmp/cmp" urlhelper "github.com/hashicorp/go-getter/v2/helper/url" @@ -45,6 +47,7 @@ func abs(t *testing.T, path string) string { func TestStepDownload_Run(t *testing.T) { srvr := httptest.NewServer(http.FileServer(http.Dir("test-fixtures"))) + defer srvr.Close() cs := map[string]string{ @@ -287,6 +290,47 @@ func TestStepDownload_download(t *testing.T) { os.RemoveAll(step.TargetPath) } +func TestStepDownload_short_timeout(t *testing.T) { + + os.Setenv("PACKER_GETTER_READ_TIMEOUT", "1ns") + prepareGetterClient() + defer os.Unsetenv("PACKER_GETTER_READ_TIMEOUT") + defer prepareGetterClient() + + if runtime.GOOS == "windows" { + t.Log("skipping download test on windows right now.") + return + } + srvr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-time.After(time.Microsecond * 100) + w.Write([]byte("should not receive this")) + })) + + defer srvr.Close() + step := &StepDownload{ + Checksum: "sha1:f572d396fae9206628714fb2ce00f72e94f2258f", + Description: "ISO", + ResultKey: "iso_path", + Url: nil, + } + ui := &packersdk.BasicUi{ + Reader: new(bytes.Buffer), + Writer: new(bytes.Buffer), + PB: &packersdk.NoopProgressTracker{}, + } + + _, err := step.download(context.TODO(), ui, srvr.URL+"/root/another.txt?") + + contextDeadlineErrorMsg := "context deadline exceeded" + if err == nil { + t.Fatalf("Bad: expected a '%s' error", contextDeadlineErrorMsg) + } + if !strings.Contains(err.Error(), contextDeadlineErrorMsg) { + t.Fatalf("Bad: expected a '%s' error, got %s", contextDeadlineErrorMsg, err.Error()) + } + +} + func TestStepDownload_WindowsParseSourceURL(t *testing.T) { if runtime.GOOS != "windows" { t.Skip("skip windows specific tests")