diff --git a/main/cmds.go b/main/cmds.go index 581a377..641dc37 100644 --- a/main/cmds.go +++ b/main/cmds.go @@ -24,7 +24,7 @@ const ( maxScriptSize = 256 * 1024 ) -type cmdFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (msg string, ewc vmextension.ErrorWithClarification) +type cmdFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (msg string, err error) type preFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) error type cmd struct { @@ -57,12 +57,12 @@ var ( } ) -func noop(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmextension.ErrorWithClarification) { +func noop(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) { ctx.Log("event", "noop") - return "", vmextension.NewErrorWithClarification(errorutil.NoError, nil) + return "", nil } -func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmextension.ErrorWithClarification) { +func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) { if err := os.MkdirAll(dataDir, 0755); err != nil { return "", vmextension.NewErrorWithClarification(errorutil.SystemError, errors.Wrap(err, "failed to create data dir")) } @@ -76,10 +76,10 @@ func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmexte ctx.Log("event", "created data dir", "path", dataDir) ctx.Log("event", "installed") - return "", vmextension.NewErrorWithClarification(errorutil.NoError, nil) + return "", nil } -func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmextension.ErrorWithClarification) { +func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) { { // a new context scope with path ctx = ctx.With("path", dataDir) ctx.Log("event", "removing data dir", "path", dataDir) @@ -89,7 +89,7 @@ func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmex ctx.Log("event", "removed data dir") } ctx.Log("event", "uninstalled") - return "", vmextension.NewErrorWithClarification(errorutil.NoError, nil) + return "", nil } func enablePre(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) error { @@ -112,18 +112,18 @@ func min(a, b int) int { return b } -func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmextension.ErrorWithClarification) { +func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) { // parse the extension handler settings (not available prior to 'enable') - cfg, ewc := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum) - if ewc.Err != nil { - ewc.Err = errors.Wrap(ewc.Err, "failed to get configuration") - return "", ewc + cfg, parseErr := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum) + if parseErr != nil { + parseErr = errors.Wrap(parseErr, "failed to get configuration") + return "", parseErr } dir := filepath.Join(dataDir, downloadDir, fmt.Sprintf("%d", seqNum)) - if ewc := downloadFiles(ctx, dir, cfg); ewc.Err != nil { - ewc.Err = errors.Wrap(ewc.Err, "processing file downloads failed") - return "", ewc + if err := downloadFiles(ctx, dir, cfg); err != nil { + err = errors.Wrap(err, "processing file downloads failed") + return "", err } // execute the command, save its error @@ -140,7 +140,7 @@ func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmexten ctx.Log("message", "error tailing stderr logs", "error", err) } - isSuccess := runErr.Err == nil + isSuccess := runErr == nil telemetry("Output", "-- stdout/stderr omitted from telemetry pipeline --", isSuccess, 0) if isSuccess { @@ -179,7 +179,7 @@ func checkAndSaveSeqNum(ctx log.Logger, seq int, mrseqPath string) (shouldExit b // downloadFiles downloads the files specified in cfg into dir (creates if does // not exist) and takes storage credentials specified in cfg into account. -func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) vmextension.ErrorWithClarification { +func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error { // - prepare the output directory for files and the command output // - create the directory if missing ctx.Log("event", "creating output directory", "path", dir) @@ -210,11 +210,11 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) vmextensio } ctx.Log("event", "download complete", "output", dir) } - return vmextension.NewErrorWithClarification(errorutil.NoError, nil) + return nil } // runCmd runs the command (extracted from cfg) in the given dir (assumed to exist). -func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc vmextension.ErrorWithClarification) { +func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc error) { ctx.Log("event", "executing command", "output", dir) var cmd string var scenario string @@ -245,18 +245,24 @@ func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc vmextension.Er } begin := time.Now() - ewc = ExecCmdInDir(cmd, dir) + + executeError := ExecCmdInDir(cmd, dir) elapsed := time.Now().Sub(begin) - isSuccess := ewc.Err == nil + isSuccess := executeError == nil telemetry("scenario", scenario, isSuccess, elapsed) - if ewc.Err != nil { + if executeError != nil { ctx.Log("event", "failed to execute command", "error", err, "output", dir) - return vmextension.NewErrorWithClarification(ewc.ErrorCode, errors.Wrap(ewc.Err, "failed to execute command")) + customErr, ok := executeError.(vmextension.ErrorWithClarification) + if ok { + customErr.Err = errors.Wrap(customErr.Err, "failed to execute command") + return customErr + } + return executeError } ctx.Log("event", "executed command", "output", dir) - return vmextension.NewErrorWithClarification(errorutil.NoError, nil) + return nil } func writeTempScript(script, dir string, skipDosToUnix bool) (string, string, error) { diff --git a/main/cmds_test.go b/main/cmds_test.go index 14c7920..6312116 100644 --- a/main/cmds_test.go +++ b/main/cmds_test.go @@ -7,6 +7,7 @@ import ( "path/filepath" "testing" + vmextension "github.com/Azure/azure-extension-platform/vmextension" "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/ahmetalpbalkan/go-httpbin" "github.com/go-kit/kit/log" @@ -85,7 +86,7 @@ func Test_runCmd_success(t *testing.T) { require.Nil(t, runCmd(log.NewNopLogger(), dir, handlerSettings{ publicSettings: publicSettings{CommandToExecute: "date"}, - }).Err, "command should run successfully") + }), "command should run successfully") // check stdout stderr files _, err = os.Stat(filepath.Join(dir, "stdout")) @@ -99,12 +100,18 @@ func Test_runCmd_fail(t *testing.T) { require.Nil(t, err) defer os.RemoveAll(dir) - ewc := runCmd(log.NewNopLogger(), dir, handlerSettings{ + runErr := runCmd(log.NewNopLogger(), dir, handlerSettings{ publicSettings: publicSettings{CommandToExecute: "non-existing-cmd"}, }) - require.Equal(t, errorutil.CommandExecution_failureExitCode, ewc.ErrorCode) - require.NotNil(t, ewc.Err, "command terminated with exit status") - require.Contains(t, ewc.Err.Error(), "failed to execute command") + customErr, ok := runErr.(vmextension.ErrorWithClarification) + if ok { + require.Equal(t, errorutil.CommandExecution_failureExitCode, customErr.ErrorCode) + require.NotNil(t, customErr, "command terminated with exit status") + require.Contains(t, customErr.Error(), "failed to execute command") + } else { + require.NotNil(t, runErr, "command should have failed") + require.Contains(t, runErr.Error(), "failed to execute command") + } } func Test_downloadFiles(t *testing.T) { @@ -125,7 +132,7 @@ func Test_downloadFiles(t *testing.T) { srv.URL + "/bytes/1000", }}, }) - require.Nil(t, ewc.Err) + require.Nil(t, ewc) // check the files f := []string{"10", "100", "1000"} diff --git a/main/exec.go b/main/exec.go index 96ed557..cc8be93 100644 --- a/main/exec.go +++ b/main/exec.go @@ -18,7 +18,7 @@ import ( // // On error, an exit code may be returned if it is an exit code error. // Given stdout and stderr will be closed upon returning. -func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, vmextension.ErrorWithClarification) { +func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, error) { defer stdout.Close() defer stderr.Close() @@ -36,7 +36,7 @@ func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, vmextension. } } if err == nil { - return 0, vmextension.NewErrorWithClarification(errorutil.NoError, nil) + return 0, nil } return 0, vmextension.NewErrorWithClarification(errorutil.CommandExecution_failedUnknownError, errors.Wrapf(err, "failed to execute command")) @@ -48,7 +48,7 @@ func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, vmextension. // // Ideally, we execute commands only once per sequence number in custom-script-extension, // and save their output under /var/lib/waagent//download//*. -func ExecCmdInDir(cmd, workdir string) vmextension.ErrorWithClarification { +func ExecCmdInDir(cmd, workdir string) error { outFn, errFn := logPaths(workdir) outF, err := os.OpenFile(outFn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) diff --git a/main/exec_test.go b/main/exec_test.go index cb98e5b..6fafa6f 100644 --- a/main/exec_test.go +++ b/main/exec_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "testing" + "github.com/Azure/azure-extension-platform/vmextension" "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/stretchr/testify/require" ) @@ -15,7 +16,7 @@ import ( func TestExec_success(t *testing.T) { v := new(mockFile) ec, err := Exec("date", "/", v, v) - require.Nil(t, err.Err, "err: %v -- out: %s", err.Err, v.b.Bytes()) + require.Nil(t, err, "err: %v -- out: %s", err, v.b.Bytes()) require.EqualValues(t, 0, ec) } @@ -25,7 +26,7 @@ func TestExec_success_redirectsStdStreams_closesFds(t *testing.T) { require.False(t, e.closed, "stderr open") _, err := Exec("/bin/echo 'I am stdout!'>&1; /bin/echo 'I am stderr!'>&2", "/", o, e) - require.Nil(t, err.Err, "err: %v -- stderr: %s", err.Err, e.b.Bytes()) + require.Nil(t, err, "err: %v -- stderr: %s", err, e.b.Bytes()) require.Equal(t, "I am stdout!\n", string(o.b.Bytes())) require.Equal(t, "I am stderr!\n", string(e.b.Bytes())) require.True(t, o.closed, "stdout closed") @@ -34,17 +35,25 @@ func TestExec_success_redirectsStdStreams_closesFds(t *testing.T) { func TestExec_failure_exitError(t *testing.T) { ec, err := Exec("exit 12", "/", new(mockFile), new(mockFile)) - require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failureExitCode) - require.NotNil(t, err.Err) - require.EqualError(t, err.Err, "command terminated with exit status=12") // error is customized + if customErr, ok := err.(vmextension.ErrorWithClarification); ok { + require.Equal(t, customErr.ErrorCode, errorutil.CommandExecution_failureExitCode) + require.NotNil(t, customErr.Err) + require.EqualError(t, customErr.Err, "command terminated with exit status=12") // error is customized + } else { + t.Errorf("error does not have ErrorCode field: %v", err) + } require.EqualValues(t, 12, ec) } func TestExec_failure_genericError(t *testing.T) { _, err := Exec("date", "/non-existing-path", new(mockFile), new(mockFile)) - require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failedUnknownError) - require.NotNil(t, err.Err) - require.Contains(t, err.Err.Error(), "failed to execute command:") // error is wrapped + if customErr, ok := err.(vmextension.ErrorWithClarification); ok { + require.Equal(t, customErr.ErrorCode, errorutil.CommandExecution_failedUnknownError) + } else { + t.Errorf("error does not have ErrorCode field: %v", err) + } + require.NotNil(t, err) + require.Contains(t, err.Error(), "failed to execute command:") // error is wrapped } func TestExec_failure_fdClosed(t *testing.T) { @@ -52,9 +61,13 @@ func TestExec_failure_fdClosed(t *testing.T) { require.Nil(t, out.Close()) _, err := Exec("date", "/", out, out) - require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failedUnknownError) - require.NotNil(t, err.Err) - require.Contains(t, err.Err.Error(), "file closed") // error is wrapped + if customErr, ok := err.(vmextension.ErrorWithClarification); ok { + require.Equal(t, customErr.ErrorCode, errorutil.CommandExecution_failedUnknownError) + } else { + t.Errorf("error does not have ErrorCode field: %v", err) + } + require.NotNil(t, err) + require.Contains(t, err.Error(), "file closed") // error is wrapped } func TestExec_failure_redirectsStdStreams_closesFds(t *testing.T) { @@ -63,8 +76,12 @@ func TestExec_failure_redirectsStdStreams_closesFds(t *testing.T) { require.False(t, e.closed, "stderr open") _, err := Exec(`/bin/echo 'I am stdout!'>&1; /bin/echo 'I am stderr!'>&2; exit 12`, "/", o, e) - require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failureExitCode) - require.NotNil(t, err.Err) + if customErr, ok := err.(vmextension.ErrorWithClarification); ok { + require.Equal(t, customErr.ErrorCode, errorutil.CommandExecution_failureExitCode) + } else { + t.Errorf("error does not have ErrorCode field: %v", err) + } + require.NotNil(t, err) require.Equal(t, "I am stdout!\n", string(o.b.Bytes())) require.Equal(t, "I am stderr!\n", string(e.b.Bytes())) require.True(t, o.closed, "stdout closed") @@ -77,7 +94,7 @@ func TestExecCmdInDir(t *testing.T) { defer os.RemoveAll(dir) ewc := ExecCmdInDir("/bin/echo 'Hello world'", dir) - require.Nil(t, ewc.Err) + require.Nil(t, ewc) require.True(t, fileExists(t, filepath.Join(dir, "stdout")), "stdout file should be created") require.True(t, fileExists(t, filepath.Join(dir, "stderr")), "stderr file should be created") @@ -92,9 +109,13 @@ func TestExecCmdInDir(t *testing.T) { func TestExecCmdInDir_cantOpenError(t *testing.T) { err := ExecCmdInDir("/bin/echo 'Hello world'", "/non-existing-dir") - require.Equal(t, err.ErrorCode, errorutil.NoError) - require.NotNil(t, err.Err) - require.Contains(t, err.Err.Error(), "failed to open stdout file") + require.NotNil(t, err) + if customErr, ok := err.(vmextension.ErrorWithClarification); ok { + require.Equal(t, customErr.ErrorCode, errorutil.NoError) + } else { + t.Errorf("error does not have ErrorCode field: %v", err) + } + require.Contains(t, err.Error(), "failed to open stdout file") } func TestExecCmdInDir_truncates(t *testing.T) { @@ -102,8 +123,8 @@ func TestExecCmdInDir_truncates(t *testing.T) { require.Nil(t, err) defer os.RemoveAll(dir) - require.Nil(t, ExecCmdInDir("/bin/echo '1:out'; /bin/echo '1:err'>&2", dir).Err) - require.Nil(t, ExecCmdInDir("/bin/echo '2:out'; /bin/echo '2:err'>&2", dir).Err) + require.Nil(t, ExecCmdInDir("/bin/echo '1:out'; /bin/echo '1:err'>&2", dir)) + require.Nil(t, ExecCmdInDir("/bin/echo '2:out'; /bin/echo '2:err'>&2", dir)) b, err := ioutil.ReadFile(filepath.Join(dir, "stdout")) require.Nil(t, err) diff --git a/main/handlersettings.go b/main/handlersettings.go index 7899cb6..a248be3 100644 --- a/main/handlersettings.go +++ b/main/handlersettings.go @@ -5,17 +5,17 @@ import ( "fmt" "path/filepath" - "github.com/go-kit/kit/log" - "github.com/pkg/errors" "github.com/Azure/azure-extension-platform/vmextension" "github.com/Azure/custom-script-extension-linux/pkg/errorutil" + "github.com/go-kit/kit/log" + "github.com/pkg/errors" ) var ( errStoragePartialCredentials = errors.New("both 'storageAccountName' and 'storageAccountKey' must be specified") errCmdTooMany = errors.New("'commandToExecute' was specified both in public and protected settings; it must be specified only once") errScriptTooMany = errors.New("'script' was specified both in public and protected settings; it must be specified only once") - errFileUrisTooMany = errors.New("'fileUris' were specified both in public and protected settings; it must be specified only once") + errFileUrisTooMany = errors.New("'fileUris' were specified both in public and protected settings; it must be specified only once") errCmdAndScript = errors.New("'commandToExecute' and 'script' were both specified, but only one is validate at a time") errCmdMissing = errors.New("'commandToExecute' is not specified") errUsingBothKeyAndMsi = errors.New("'storageAccountName' or 'storageAccountKey' must not be specified with 'managedServiceIdentity'") @@ -120,7 +120,7 @@ func (self *clientOrObjectId) isEmpty() bool { // parseAndValidateSettings reads configuration from configFolder, decrypts it, // runs JSON-schema and logical validation on it and returns it back. -func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int) (h handlerSettings, _ vmextension.ErrorWithClarification) { +func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int) (h handlerSettings, _ error) { ctx.Log("event", "reading configuration") pubJSON, protJSON, err := readSettings(configFolder, seqNum) if err != nil { @@ -130,13 +130,13 @@ func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int) ctx.Log("event", "validating json schema") if err := validateSettingsSchema(pubJSON, protJSON); err != nil { - return h, vmextension.NewErrorWithClarification(errorutil.Internal_badConfig, errors.Wrap(err, "json validation error")) + return h, vmextension.NewErrorWithClarification(errorutil.Internal_badConfig, errors.Wrap(err, "json validation error")) } ctx.Log("event", "json schema valid") ctx.Log("event", "parsing configuration json") if err := UnmarshalHandlerSettings(pubJSON, protJSON, &h.publicSettings, &h.protectedSettings); err != nil { - return h, vmextension.NewErrorWithClarification(errorutil.Internal_badConfig, errors.Wrap(err, "json parsing error")) + return h, vmextension.NewErrorWithClarification(errorutil.Internal_badConfig, errors.Wrap(err, "json parsing error")) } ctx.Log("event", "parsed configuration json") @@ -146,7 +146,7 @@ func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int) return h, ewc } ctx.Log("event", "validated configuration") - return h, vmextension.NewErrorWithClarification(errorutil.NoError, nil) + return h, nil } // readSettings uses specified configFolder (comes from HandlerEnvironment) to diff --git a/main/main.go b/main/main.go index 990e5ad..2c43042 100644 --- a/main/main.go +++ b/main/main.go @@ -6,6 +6,8 @@ import ( "strconv" "strings" + "github.com/Azure/azure-extension-platform/vmextension" + "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/go-kit/kit/log" "github.com/pkg/errors" ) @@ -79,11 +81,16 @@ func main() { } // execute the subcommand reportStatus(ctx, hEnv, seqNum, StatusTransitioning, cmd, "") - msg, ewc := cmd.f(ctx, hEnv, seqNum) - if ewc.Err != nil { - ctx.Log("event", "failed to handle", "error", ewc.Error()) - ewc.Err = errors.Wrap(ewc.Err, ewc.Error()+msg) - reportErrorStatus(ctx, hEnv, seqNum, StatusError, cmd, ewc) + msg, err := cmd.f(ctx, hEnv, seqNum) + if err != nil { + ctx.Log("event", "failed to handle", "error", err) + err = errors.Wrap(err, err.Error()+msg) + ewc, ok := err.(vmextension.ErrorWithClarification) + if ok { + reportErrorStatus(ctx, hEnv, seqNum, StatusError, cmd, ewc) + } else { + reportErrorStatus(ctx, hEnv, seqNum, StatusError, cmd, vmextension.NewErrorWithClarification(errorutil.NoError, err)) + } os.Exit(cmd.failExitCode) } reportStatus(ctx, hEnv, seqNum, StatusSuccess, cmd, msg)