diff --git a/main/cmds.go b/main/cmds.go
index 581a377..641dc37 100644
--- a/main/cmds.go
+++ b/main/cmds.go
@@ -24,7 +24,7 @@ const (
maxScriptSize = 256 * 1024
)
-type cmdFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (msg string, ewc vmextension.ErrorWithClarification)
+type cmdFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (msg string, err error)
type preFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) error
type cmd struct {
@@ -57,12 +57,12 @@ var (
}
)
-func noop(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmextension.ErrorWithClarification) {
+func noop(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
ctx.Log("event", "noop")
- return "", vmextension.NewErrorWithClarification(errorutil.NoError, nil)
+ return "", nil
}
-func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmextension.ErrorWithClarification) {
+func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
if err := os.MkdirAll(dataDir, 0755); err != nil {
return "", vmextension.NewErrorWithClarification(errorutil.SystemError, errors.Wrap(err, "failed to create data dir"))
}
@@ -76,10 +76,10 @@ func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmexte
ctx.Log("event", "created data dir", "path", dataDir)
ctx.Log("event", "installed")
- return "", vmextension.NewErrorWithClarification(errorutil.NoError, nil)
+ return "", nil
}
-func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmextension.ErrorWithClarification) {
+func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
{ // a new context scope with path
ctx = ctx.With("path", dataDir)
ctx.Log("event", "removing data dir", "path", dataDir)
@@ -89,7 +89,7 @@ func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmex
ctx.Log("event", "removed data dir")
}
ctx.Log("event", "uninstalled")
- return "", vmextension.NewErrorWithClarification(errorutil.NoError, nil)
+ return "", nil
}
func enablePre(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) error {
@@ -112,18 +112,18 @@ func min(a, b int) int {
return b
}
-func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmextension.ErrorWithClarification) {
+func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
// parse the extension handler settings (not available prior to 'enable')
- cfg, ewc := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum)
- if ewc.Err != nil {
- ewc.Err = errors.Wrap(ewc.Err, "failed to get configuration")
- return "", ewc
+ cfg, parseErr := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum)
+ if parseErr != nil {
+ parseErr = errors.Wrap(parseErr, "failed to get configuration")
+ return "", parseErr
}
dir := filepath.Join(dataDir, downloadDir, fmt.Sprintf("%d", seqNum))
- if ewc := downloadFiles(ctx, dir, cfg); ewc.Err != nil {
- ewc.Err = errors.Wrap(ewc.Err, "processing file downloads failed")
- return "", ewc
+ if err := downloadFiles(ctx, dir, cfg); err != nil {
+ err = errors.Wrap(err, "processing file downloads failed")
+ return "", err
}
// execute the command, save its error
@@ -140,7 +140,7 @@ func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmexten
ctx.Log("message", "error tailing stderr logs", "error", err)
}
- isSuccess := runErr.Err == nil
+ isSuccess := runErr == nil
telemetry("Output", "-- stdout/stderr omitted from telemetry pipeline --", isSuccess, 0)
if isSuccess {
@@ -179,7 +179,7 @@ func checkAndSaveSeqNum(ctx log.Logger, seq int, mrseqPath string) (shouldExit b
// downloadFiles downloads the files specified in cfg into dir (creates if does
// not exist) and takes storage credentials specified in cfg into account.
-func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) vmextension.ErrorWithClarification {
+func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error {
// - prepare the output directory for files and the command output
// - create the directory if missing
ctx.Log("event", "creating output directory", "path", dir)
@@ -210,11 +210,11 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) vmextensio
}
ctx.Log("event", "download complete", "output", dir)
}
- return vmextension.NewErrorWithClarification(errorutil.NoError, nil)
+ return nil
}
// runCmd runs the command (extracted from cfg) in the given dir (assumed to exist).
-func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc vmextension.ErrorWithClarification) {
+func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc error) {
ctx.Log("event", "executing command", "output", dir)
var cmd string
var scenario string
@@ -245,18 +245,24 @@ func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc vmextension.Er
}
begin := time.Now()
- ewc = ExecCmdInDir(cmd, dir)
+
+ executeError := ExecCmdInDir(cmd, dir)
elapsed := time.Now().Sub(begin)
- isSuccess := ewc.Err == nil
+ isSuccess := executeError == nil
telemetry("scenario", scenario, isSuccess, elapsed)
- if ewc.Err != nil {
+ if executeError != nil {
ctx.Log("event", "failed to execute command", "error", err, "output", dir)
- return vmextension.NewErrorWithClarification(ewc.ErrorCode, errors.Wrap(ewc.Err, "failed to execute command"))
+ customErr, ok := executeError.(vmextension.ErrorWithClarification)
+ if ok {
+ customErr.Err = errors.Wrap(customErr.Err, "failed to execute command")
+ return customErr
+ }
+ return executeError
}
ctx.Log("event", "executed command", "output", dir)
- return vmextension.NewErrorWithClarification(errorutil.NoError, nil)
+ return nil
}
func writeTempScript(script, dir string, skipDosToUnix bool) (string, string, error) {
diff --git a/main/cmds_test.go b/main/cmds_test.go
index 14c7920..6312116 100644
--- a/main/cmds_test.go
+++ b/main/cmds_test.go
@@ -7,6 +7,7 @@ import (
"path/filepath"
"testing"
+ vmextension "github.com/Azure/azure-extension-platform/vmextension"
"github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/ahmetalpbalkan/go-httpbin"
"github.com/go-kit/kit/log"
@@ -85,7 +86,7 @@ func Test_runCmd_success(t *testing.T) {
require.Nil(t, runCmd(log.NewNopLogger(), dir, handlerSettings{
publicSettings: publicSettings{CommandToExecute: "date"},
- }).Err, "command should run successfully")
+ }), "command should run successfully")
// check stdout stderr files
_, err = os.Stat(filepath.Join(dir, "stdout"))
@@ -99,12 +100,18 @@ func Test_runCmd_fail(t *testing.T) {
require.Nil(t, err)
defer os.RemoveAll(dir)
- ewc := runCmd(log.NewNopLogger(), dir, handlerSettings{
+ runErr := runCmd(log.NewNopLogger(), dir, handlerSettings{
publicSettings: publicSettings{CommandToExecute: "non-existing-cmd"},
})
- require.Equal(t, errorutil.CommandExecution_failureExitCode, ewc.ErrorCode)
- require.NotNil(t, ewc.Err, "command terminated with exit status")
- require.Contains(t, ewc.Err.Error(), "failed to execute command")
+ customErr, ok := runErr.(vmextension.ErrorWithClarification)
+ if ok {
+ require.Equal(t, errorutil.CommandExecution_failureExitCode, customErr.ErrorCode)
+ require.NotNil(t, customErr, "command terminated with exit status")
+ require.Contains(t, customErr.Error(), "failed to execute command")
+ } else {
+ require.NotNil(t, runErr, "command should have failed")
+ require.Contains(t, runErr.Error(), "failed to execute command")
+ }
}
func Test_downloadFiles(t *testing.T) {
@@ -125,7 +132,7 @@ func Test_downloadFiles(t *testing.T) {
srv.URL + "/bytes/1000",
}},
})
- require.Nil(t, ewc.Err)
+ require.Nil(t, ewc)
// check the files
f := []string{"10", "100", "1000"}
diff --git a/main/exec.go b/main/exec.go
index 96ed557..cc8be93 100644
--- a/main/exec.go
+++ b/main/exec.go
@@ -18,7 +18,7 @@ import (
//
// On error, an exit code may be returned if it is an exit code error.
// Given stdout and stderr will be closed upon returning.
-func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, vmextension.ErrorWithClarification) {
+func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, error) {
defer stdout.Close()
defer stderr.Close()
@@ -36,7 +36,7 @@ func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, vmextension.
}
}
if err == nil {
- return 0, vmextension.NewErrorWithClarification(errorutil.NoError, nil)
+ return 0, nil
}
return 0, vmextension.NewErrorWithClarification(errorutil.CommandExecution_failedUnknownError, errors.Wrapf(err, "failed to execute command"))
@@ -48,7 +48,7 @@ func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, vmextension.
//
// Ideally, we execute commands only once per sequence number in custom-script-extension,
// and save their output under /var/lib/waagent/
/download//*.
-func ExecCmdInDir(cmd, workdir string) vmextension.ErrorWithClarification {
+func ExecCmdInDir(cmd, workdir string) error {
outFn, errFn := logPaths(workdir)
outF, err := os.OpenFile(outFn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
diff --git a/main/exec_test.go b/main/exec_test.go
index cb98e5b..6fafa6f 100644
--- a/main/exec_test.go
+++ b/main/exec_test.go
@@ -8,6 +8,7 @@ import (
"path/filepath"
"testing"
+ "github.com/Azure/azure-extension-platform/vmextension"
"github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/stretchr/testify/require"
)
@@ -15,7 +16,7 @@ import (
func TestExec_success(t *testing.T) {
v := new(mockFile)
ec, err := Exec("date", "/", v, v)
- require.Nil(t, err.Err, "err: %v -- out: %s", err.Err, v.b.Bytes())
+ require.Nil(t, err, "err: %v -- out: %s", err, v.b.Bytes())
require.EqualValues(t, 0, ec)
}
@@ -25,7 +26,7 @@ func TestExec_success_redirectsStdStreams_closesFds(t *testing.T) {
require.False(t, e.closed, "stderr open")
_, err := Exec("/bin/echo 'I am stdout!'>&1; /bin/echo 'I am stderr!'>&2", "/", o, e)
- require.Nil(t, err.Err, "err: %v -- stderr: %s", err.Err, e.b.Bytes())
+ require.Nil(t, err, "err: %v -- stderr: %s", err, e.b.Bytes())
require.Equal(t, "I am stdout!\n", string(o.b.Bytes()))
require.Equal(t, "I am stderr!\n", string(e.b.Bytes()))
require.True(t, o.closed, "stdout closed")
@@ -34,17 +35,25 @@ func TestExec_success_redirectsStdStreams_closesFds(t *testing.T) {
func TestExec_failure_exitError(t *testing.T) {
ec, err := Exec("exit 12", "/", new(mockFile), new(mockFile))
- require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failureExitCode)
- require.NotNil(t, err.Err)
- require.EqualError(t, err.Err, "command terminated with exit status=12") // error is customized
+ if customErr, ok := err.(vmextension.ErrorWithClarification); ok {
+ require.Equal(t, customErr.ErrorCode, errorutil.CommandExecution_failureExitCode)
+ require.NotNil(t, customErr.Err)
+ require.EqualError(t, customErr.Err, "command terminated with exit status=12") // error is customized
+ } else {
+ t.Errorf("error does not have ErrorCode field: %v", err)
+ }
require.EqualValues(t, 12, ec)
}
func TestExec_failure_genericError(t *testing.T) {
_, err := Exec("date", "/non-existing-path", new(mockFile), new(mockFile))
- require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failedUnknownError)
- require.NotNil(t, err.Err)
- require.Contains(t, err.Err.Error(), "failed to execute command:") // error is wrapped
+ if customErr, ok := err.(vmextension.ErrorWithClarification); ok {
+ require.Equal(t, customErr.ErrorCode, errorutil.CommandExecution_failedUnknownError)
+ } else {
+ t.Errorf("error does not have ErrorCode field: %v", err)
+ }
+ require.NotNil(t, err)
+ require.Contains(t, err.Error(), "failed to execute command:") // error is wrapped
}
func TestExec_failure_fdClosed(t *testing.T) {
@@ -52,9 +61,13 @@ func TestExec_failure_fdClosed(t *testing.T) {
require.Nil(t, out.Close())
_, err := Exec("date", "/", out, out)
- require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failedUnknownError)
- require.NotNil(t, err.Err)
- require.Contains(t, err.Err.Error(), "file closed") // error is wrapped
+ if customErr, ok := err.(vmextension.ErrorWithClarification); ok {
+ require.Equal(t, customErr.ErrorCode, errorutil.CommandExecution_failedUnknownError)
+ } else {
+ t.Errorf("error does not have ErrorCode field: %v", err)
+ }
+ require.NotNil(t, err)
+ require.Contains(t, err.Error(), "file closed") // error is wrapped
}
func TestExec_failure_redirectsStdStreams_closesFds(t *testing.T) {
@@ -63,8 +76,12 @@ func TestExec_failure_redirectsStdStreams_closesFds(t *testing.T) {
require.False(t, e.closed, "stderr open")
_, err := Exec(`/bin/echo 'I am stdout!'>&1; /bin/echo 'I am stderr!'>&2; exit 12`, "/", o, e)
- require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failureExitCode)
- require.NotNil(t, err.Err)
+ if customErr, ok := err.(vmextension.ErrorWithClarification); ok {
+ require.Equal(t, customErr.ErrorCode, errorutil.CommandExecution_failureExitCode)
+ } else {
+ t.Errorf("error does not have ErrorCode field: %v", err)
+ }
+ require.NotNil(t, err)
require.Equal(t, "I am stdout!\n", string(o.b.Bytes()))
require.Equal(t, "I am stderr!\n", string(e.b.Bytes()))
require.True(t, o.closed, "stdout closed")
@@ -77,7 +94,7 @@ func TestExecCmdInDir(t *testing.T) {
defer os.RemoveAll(dir)
ewc := ExecCmdInDir("/bin/echo 'Hello world'", dir)
- require.Nil(t, ewc.Err)
+ require.Nil(t, ewc)
require.True(t, fileExists(t, filepath.Join(dir, "stdout")), "stdout file should be created")
require.True(t, fileExists(t, filepath.Join(dir, "stderr")), "stderr file should be created")
@@ -92,9 +109,13 @@ func TestExecCmdInDir(t *testing.T) {
func TestExecCmdInDir_cantOpenError(t *testing.T) {
err := ExecCmdInDir("/bin/echo 'Hello world'", "/non-existing-dir")
- require.Equal(t, err.ErrorCode, errorutil.NoError)
- require.NotNil(t, err.Err)
- require.Contains(t, err.Err.Error(), "failed to open stdout file")
+ require.NotNil(t, err)
+ if customErr, ok := err.(vmextension.ErrorWithClarification); ok {
+ require.Equal(t, customErr.ErrorCode, errorutil.NoError)
+ } else {
+ t.Errorf("error does not have ErrorCode field: %v", err)
+ }
+ require.Contains(t, err.Error(), "failed to open stdout file")
}
func TestExecCmdInDir_truncates(t *testing.T) {
@@ -102,8 +123,8 @@ func TestExecCmdInDir_truncates(t *testing.T) {
require.Nil(t, err)
defer os.RemoveAll(dir)
- require.Nil(t, ExecCmdInDir("/bin/echo '1:out'; /bin/echo '1:err'>&2", dir).Err)
- require.Nil(t, ExecCmdInDir("/bin/echo '2:out'; /bin/echo '2:err'>&2", dir).Err)
+ require.Nil(t, ExecCmdInDir("/bin/echo '1:out'; /bin/echo '1:err'>&2", dir))
+ require.Nil(t, ExecCmdInDir("/bin/echo '2:out'; /bin/echo '2:err'>&2", dir))
b, err := ioutil.ReadFile(filepath.Join(dir, "stdout"))
require.Nil(t, err)
diff --git a/main/handlersettings.go b/main/handlersettings.go
index 7899cb6..a248be3 100644
--- a/main/handlersettings.go
+++ b/main/handlersettings.go
@@ -5,17 +5,17 @@ import (
"fmt"
"path/filepath"
- "github.com/go-kit/kit/log"
- "github.com/pkg/errors"
"github.com/Azure/azure-extension-platform/vmextension"
"github.com/Azure/custom-script-extension-linux/pkg/errorutil"
+ "github.com/go-kit/kit/log"
+ "github.com/pkg/errors"
)
var (
errStoragePartialCredentials = errors.New("both 'storageAccountName' and 'storageAccountKey' must be specified")
errCmdTooMany = errors.New("'commandToExecute' was specified both in public and protected settings; it must be specified only once")
errScriptTooMany = errors.New("'script' was specified both in public and protected settings; it must be specified only once")
- errFileUrisTooMany = errors.New("'fileUris' were specified both in public and protected settings; it must be specified only once")
+ errFileUrisTooMany = errors.New("'fileUris' were specified both in public and protected settings; it must be specified only once")
errCmdAndScript = errors.New("'commandToExecute' and 'script' were both specified, but only one is validate at a time")
errCmdMissing = errors.New("'commandToExecute' is not specified")
errUsingBothKeyAndMsi = errors.New("'storageAccountName' or 'storageAccountKey' must not be specified with 'managedServiceIdentity'")
@@ -120,7 +120,7 @@ func (self *clientOrObjectId) isEmpty() bool {
// parseAndValidateSettings reads configuration from configFolder, decrypts it,
// runs JSON-schema and logical validation on it and returns it back.
-func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int) (h handlerSettings, _ vmextension.ErrorWithClarification) {
+func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int) (h handlerSettings, _ error) {
ctx.Log("event", "reading configuration")
pubJSON, protJSON, err := readSettings(configFolder, seqNum)
if err != nil {
@@ -130,13 +130,13 @@ func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int)
ctx.Log("event", "validating json schema")
if err := validateSettingsSchema(pubJSON, protJSON); err != nil {
- return h, vmextension.NewErrorWithClarification(errorutil.Internal_badConfig, errors.Wrap(err, "json validation error"))
+ return h, vmextension.NewErrorWithClarification(errorutil.Internal_badConfig, errors.Wrap(err, "json validation error"))
}
ctx.Log("event", "json schema valid")
ctx.Log("event", "parsing configuration json")
if err := UnmarshalHandlerSettings(pubJSON, protJSON, &h.publicSettings, &h.protectedSettings); err != nil {
- return h, vmextension.NewErrorWithClarification(errorutil.Internal_badConfig, errors.Wrap(err, "json parsing error"))
+ return h, vmextension.NewErrorWithClarification(errorutil.Internal_badConfig, errors.Wrap(err, "json parsing error"))
}
ctx.Log("event", "parsed configuration json")
@@ -146,7 +146,7 @@ func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int)
return h, ewc
}
ctx.Log("event", "validated configuration")
- return h, vmextension.NewErrorWithClarification(errorutil.NoError, nil)
+ return h, nil
}
// readSettings uses specified configFolder (comes from HandlerEnvironment) to
diff --git a/main/main.go b/main/main.go
index 990e5ad..2c43042 100644
--- a/main/main.go
+++ b/main/main.go
@@ -6,6 +6,8 @@ import (
"strconv"
"strings"
+ "github.com/Azure/azure-extension-platform/vmextension"
+ "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/go-kit/kit/log"
"github.com/pkg/errors"
)
@@ -79,11 +81,16 @@ func main() {
}
// execute the subcommand
reportStatus(ctx, hEnv, seqNum, StatusTransitioning, cmd, "")
- msg, ewc := cmd.f(ctx, hEnv, seqNum)
- if ewc.Err != nil {
- ctx.Log("event", "failed to handle", "error", ewc.Error())
- ewc.Err = errors.Wrap(ewc.Err, ewc.Error()+msg)
- reportErrorStatus(ctx, hEnv, seqNum, StatusError, cmd, ewc)
+ msg, err := cmd.f(ctx, hEnv, seqNum)
+ if err != nil {
+ ctx.Log("event", "failed to handle", "error", err)
+ err = errors.Wrap(err, err.Error()+msg)
+ ewc, ok := err.(vmextension.ErrorWithClarification)
+ if ok {
+ reportErrorStatus(ctx, hEnv, seqNum, StatusError, cmd, ewc)
+ } else {
+ reportErrorStatus(ctx, hEnv, seqNum, StatusError, cmd, vmextension.NewErrorWithClarification(errorutil.NoError, err))
+ }
os.Exit(cmd.failExitCode)
}
reportStatus(ctx, hEnv, seqNum, StatusSuccess, cmd, msg)