Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions main/cmds.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (
maxScriptSize = 256 * 1024
)

type cmdFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (msg string, ewc vmextension.ErrorWithClarification)
type cmdFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (msg string, err error)
type preFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) error

type cmd struct {
Expand Down Expand Up @@ -57,12 +57,12 @@ var (
}
)

func noop(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmextension.ErrorWithClarification) {
func noop(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
ctx.Log("event", "noop")
return "", vmextension.NewErrorWithClarification(errorutil.NoError, nil)
return "", nil
}

func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmextension.ErrorWithClarification) {
func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
if err := os.MkdirAll(dataDir, 0755); err != nil {
return "", vmextension.NewErrorWithClarification(errorutil.SystemError, errors.Wrap(err, "failed to create data dir"))
}
Expand All @@ -76,10 +76,10 @@ func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmexte

ctx.Log("event", "created data dir", "path", dataDir)
ctx.Log("event", "installed")
return "", vmextension.NewErrorWithClarification(errorutil.NoError, nil)
return "", nil
}

func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmextension.ErrorWithClarification) {
func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
{ // a new context scope with path
ctx = ctx.With("path", dataDir)
ctx.Log("event", "removing data dir", "path", dataDir)
Expand All @@ -89,7 +89,7 @@ func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmex
ctx.Log("event", "removed data dir")
}
ctx.Log("event", "uninstalled")
return "", vmextension.NewErrorWithClarification(errorutil.NoError, nil)
return "", nil
}

func enablePre(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) error {
Expand All @@ -112,18 +112,18 @@ func min(a, b int) int {
return b
}

func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmextension.ErrorWithClarification) {
func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
// parse the extension handler settings (not available prior to 'enable')
cfg, ewc := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum)
if ewc.Err != nil {
ewc.Err = errors.Wrap(ewc.Err, "failed to get configuration")
if ewc != nil {
ewc = errors.Wrap(ewc, "failed to get configuration")
return "", ewc
}

dir := filepath.Join(dataDir, downloadDir, fmt.Sprintf("%d", seqNum))
if ewc := downloadFiles(ctx, dir, cfg); ewc.Err != nil {
ewc.Err = errors.Wrap(ewc.Err, "processing file downloads failed")
return "", ewc
if err := downloadFiles(ctx, dir, cfg); err != nil {
err = errors.Wrap(err, "processing file downloads failed")
return "", err
}

// execute the command, save its error
Expand All @@ -140,7 +140,7 @@ func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, vmexten
ctx.Log("message", "error tailing stderr logs", "error", err)
}

isSuccess := runErr.Err == nil
isSuccess := runErr == nil
telemetry("Output", "-- stdout/stderr omitted from telemetry pipeline --", isSuccess, 0)

if isSuccess {
Expand Down Expand Up @@ -179,7 +179,7 @@ func checkAndSaveSeqNum(ctx log.Logger, seq int, mrseqPath string) (shouldExit b

// downloadFiles downloads the files specified in cfg into dir (creates if does
// not exist) and takes storage credentials specified in cfg into account.
func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) vmextension.ErrorWithClarification {
func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error {
// - prepare the output directory for files and the command output
// - create the directory if missing
ctx.Log("event", "creating output directory", "path", dir)
Expand Down Expand Up @@ -210,11 +210,11 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) vmextensio
}
ctx.Log("event", "download complete", "output", dir)
}
return vmextension.NewErrorWithClarification(errorutil.NoError, nil)
return nil
}

// runCmd runs the command (extracted from cfg) in the given dir (assumed to exist).
func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc vmextension.ErrorWithClarification) {
func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc error) {
ctx.Log("event", "executing command", "output", dir)
var cmd string
var scenario string
Expand Down Expand Up @@ -245,18 +245,21 @@ func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc vmextension.Er
}

begin := time.Now()
ewc = ExecCmdInDir(cmd, dir)

executeError := ExecCmdInDir(cmd, dir)
elapsed := time.Now().Sub(begin)
isSuccess := ewc.Err == nil
isSuccess := executeError == nil

telemetry("scenario", scenario, isSuccess, elapsed)

if ewc.Err != nil {
if executeError != nil {
ctx.Log("event", "failed to execute command", "error", err, "output", dir)
return vmextension.NewErrorWithClarification(ewc.ErrorCode, errors.Wrap(ewc.Err, "failed to execute command"))
customErr := executeError.(vmextension.ErrorWithClarification)
customErr.Err = errors.Wrap(customErr.Err, "failed to execute command")
return customErr
}
ctx.Log("event", "executed command", "output", dir)
return vmextension.NewErrorWithClarification(errorutil.NoError, nil)
return nil
}

func writeTempScript(script, dir string, skipDosToUnix bool) (string, string, error) {
Expand Down
12 changes: 7 additions & 5 deletions main/cmds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"path/filepath"
"testing"

vmextension "github.com/Azure/azure-extension-platform/vmextension"
"github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/ahmetalpbalkan/go-httpbin"
"github.com/go-kit/kit/log"
Expand Down Expand Up @@ -85,7 +86,7 @@ func Test_runCmd_success(t *testing.T) {

require.Nil(t, runCmd(log.NewNopLogger(), dir, handlerSettings{
publicSettings: publicSettings{CommandToExecute: "date"},
}).Err, "command should run successfully")
}), "command should run successfully")

// check stdout stderr files
_, err = os.Stat(filepath.Join(dir, "stdout"))
Expand All @@ -102,9 +103,10 @@ func Test_runCmd_fail(t *testing.T) {
ewc := runCmd(log.NewNopLogger(), dir, handlerSettings{
publicSettings: publicSettings{CommandToExecute: "non-existing-cmd"},
})
require.Equal(t, errorutil.CommandExecution_failureExitCode, ewc.ErrorCode)
require.NotNil(t, ewc.Err, "command terminated with exit status")
require.Contains(t, ewc.Err.Error(), "failed to execute command")
customErr := ewc.(vmextension.ErrorWithClarification)
require.Equal(t, errorutil.CommandExecution_failureExitCode, customErr.ErrorCode)
require.NotNil(t, customErr, "command terminated with exit status")
require.Contains(t, customErr.Error(), "failed to execute command")
}

func Test_downloadFiles(t *testing.T) {
Expand All @@ -125,7 +127,7 @@ func Test_downloadFiles(t *testing.T) {
srv.URL + "/bytes/1000",
}},
})
require.Nil(t, ewc.Err)
require.Nil(t, ewc)

// check the files
f := []string{"10", "100", "1000"}
Expand Down
6 changes: 3 additions & 3 deletions main/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
//
// On error, an exit code may be returned if it is an exit code error.
// Given stdout and stderr will be closed upon returning.
func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, vmextension.ErrorWithClarification) {
func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, error) {
defer stdout.Close()
defer stderr.Close()

Expand All @@ -36,7 +36,7 @@ func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, vmextension.
}
}
if err == nil {
return 0, vmextension.NewErrorWithClarification(errorutil.NoError, nil)
return 0, nil
}
return 0, vmextension.NewErrorWithClarification(errorutil.CommandExecution_failedUnknownError, errors.Wrapf(err, "failed to execute command"))

Expand All @@ -48,7 +48,7 @@ func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, vmextension.
//
// Ideally, we execute commands only once per sequence number in custom-script-extension,
// and save their output under /var/lib/waagent/<dir>/download/<seqnum>/*.
func ExecCmdInDir(cmd, workdir string) vmextension.ErrorWithClarification {
func ExecCmdInDir(cmd, workdir string) error {
outFn, errFn := logPaths(workdir)

outF, err := os.OpenFile(outFn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
Expand Down
59 changes: 40 additions & 19 deletions main/exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import (
"path/filepath"
"testing"

"github.com/Azure/azure-extension-platform/vmextension"
"github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/stretchr/testify/require"
)

func TestExec_success(t *testing.T) {
v := new(mockFile)
ec, err := Exec("date", "/", v, v)
require.Nil(t, err.Err, "err: %v -- out: %s", err.Err, v.b.Bytes())
require.Nil(t, err, "err: %v -- out: %s", err, v.b.Bytes())
require.EqualValues(t, 0, ec)
}

Expand All @@ -25,7 +26,7 @@ func TestExec_success_redirectsStdStreams_closesFds(t *testing.T) {
require.False(t, e.closed, "stderr open")

_, err := Exec("/bin/echo 'I am stdout!'>&1; /bin/echo 'I am stderr!'>&2", "/", o, e)
require.Nil(t, err.Err, "err: %v -- stderr: %s", err.Err, e.b.Bytes())
require.Nil(t, err, "err: %v -- stderr: %s", err, e.b.Bytes())
require.Equal(t, "I am stdout!\n", string(o.b.Bytes()))
require.Equal(t, "I am stderr!\n", string(e.b.Bytes()))
require.True(t, o.closed, "stdout closed")
Expand All @@ -34,27 +35,39 @@ func TestExec_success_redirectsStdStreams_closesFds(t *testing.T) {

func TestExec_failure_exitError(t *testing.T) {
ec, err := Exec("exit 12", "/", new(mockFile), new(mockFile))
require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failureExitCode)
require.NotNil(t, err.Err)
require.EqualError(t, err.Err, "command terminated with exit status=12") // error is customized
if customErr, ok := err.(vmextension.ErrorWithClarification); ok {
require.Equal(t, customErr.ErrorCode, errorutil.CommandExecution_failureExitCode)
require.NotNil(t, customErr.Err)
require.EqualError(t, customErr.Err, "command terminated with exit status=12") // error is customized
} else {
t.Errorf("error does not have ErrorCode field: %v", err)
}
require.EqualValues(t, 12, ec)
}

func TestExec_failure_genericError(t *testing.T) {
_, err := Exec("date", "/non-existing-path", new(mockFile), new(mockFile))
require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failedUnknownError)
require.NotNil(t, err.Err)
require.Contains(t, err.Err.Error(), "failed to execute command:") // error is wrapped
if customErr, ok := err.(vmextension.ErrorWithClarification); ok {
require.Equal(t, customErr.ErrorCode, errorutil.CommandExecution_failedUnknownError)
} else {
t.Errorf("error does not have ErrorCode field: %v", err)
}
require.NotNil(t, err)
require.Contains(t, err.Error(), "failed to execute command:") // error is wrapped
}

func TestExec_failure_fdClosed(t *testing.T) {
out := new(mockFile)
require.Nil(t, out.Close())

_, err := Exec("date", "/", out, out)
require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failedUnknownError)
require.NotNil(t, err.Err)
require.Contains(t, err.Err.Error(), "file closed") // error is wrapped
if customErr, ok := err.(vmextension.ErrorWithClarification); ok {
require.Equal(t, customErr.ErrorCode, errorutil.CommandExecution_failedUnknownError)
} else {
t.Errorf("error does not have ErrorCode field: %v", err)
}
require.NotNil(t, err)
require.Contains(t, err.Error(), "file closed") // error is wrapped
}

func TestExec_failure_redirectsStdStreams_closesFds(t *testing.T) {
Expand All @@ -63,8 +76,12 @@ func TestExec_failure_redirectsStdStreams_closesFds(t *testing.T) {
require.False(t, e.closed, "stderr open")

_, err := Exec(`/bin/echo 'I am stdout!'>&1; /bin/echo 'I am stderr!'>&2; exit 12`, "/", o, e)
require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failureExitCode)
require.NotNil(t, err.Err)
if customErr, ok := err.(vmextension.ErrorWithClarification); ok {
require.Equal(t, customErr.ErrorCode, errorutil.CommandExecution_failureExitCode)
} else {
t.Errorf("error does not have ErrorCode field: %v", err)
}
require.NotNil(t, err)
require.Equal(t, "I am stdout!\n", string(o.b.Bytes()))
require.Equal(t, "I am stderr!\n", string(e.b.Bytes()))
require.True(t, o.closed, "stdout closed")
Expand All @@ -77,7 +94,7 @@ func TestExecCmdInDir(t *testing.T) {
defer os.RemoveAll(dir)

ewc := ExecCmdInDir("/bin/echo 'Hello world'", dir)
require.Nil(t, ewc.Err)
require.Nil(t, ewc)
require.True(t, fileExists(t, filepath.Join(dir, "stdout")), "stdout file should be created")
require.True(t, fileExists(t, filepath.Join(dir, "stderr")), "stderr file should be created")

Expand All @@ -92,18 +109,22 @@ func TestExecCmdInDir(t *testing.T) {

func TestExecCmdInDir_cantOpenError(t *testing.T) {
err := ExecCmdInDir("/bin/echo 'Hello world'", "/non-existing-dir")
require.Equal(t, err.ErrorCode, errorutil.NoError)
require.NotNil(t, err.Err)
require.Contains(t, err.Err.Error(), "failed to open stdout file")
require.NotNil(t, err)
if customErr, ok := err.(vmextension.ErrorWithClarification); ok {
require.Equal(t, customErr.ErrorCode, errorutil.NoError)
} else {
t.Errorf("error does not have ErrorCode field: %v", err)
}
require.Contains(t, err.Error(), "failed to open stdout file")
}

func TestExecCmdInDir_truncates(t *testing.T) {
dir, err := ioutil.TempDir("", "")
require.Nil(t, err)
defer os.RemoveAll(dir)

require.Nil(t, ExecCmdInDir("/bin/echo '1:out'; /bin/echo '1:err'>&2", dir).Err)
require.Nil(t, ExecCmdInDir("/bin/echo '2:out'; /bin/echo '2:err'>&2", dir).Err)
require.Nil(t, ExecCmdInDir("/bin/echo '1:out'; /bin/echo '1:err'>&2", dir))
require.Nil(t, ExecCmdInDir("/bin/echo '2:out'; /bin/echo '2:err'>&2", dir))

b, err := ioutil.ReadFile(filepath.Join(dir, "stdout"))
require.Nil(t, err)
Expand Down
14 changes: 7 additions & 7 deletions main/handlersettings.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@ import (
"fmt"
"path/filepath"

"github.com/go-kit/kit/log"
"github.com/pkg/errors"
"github.com/Azure/azure-extension-platform/vmextension"
"github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/go-kit/kit/log"
"github.com/pkg/errors"
)

var (
errStoragePartialCredentials = errors.New("both 'storageAccountName' and 'storageAccountKey' must be specified")
errCmdTooMany = errors.New("'commandToExecute' was specified both in public and protected settings; it must be specified only once")
errScriptTooMany = errors.New("'script' was specified both in public and protected settings; it must be specified only once")
errFileUrisTooMany = errors.New("'fileUris' were specified both in public and protected settings; it must be specified only once")
errFileUrisTooMany = errors.New("'fileUris' were specified both in public and protected settings; it must be specified only once")
errCmdAndScript = errors.New("'commandToExecute' and 'script' were both specified, but only one is validate at a time")
errCmdMissing = errors.New("'commandToExecute' is not specified")
errUsingBothKeyAndMsi = errors.New("'storageAccountName' or 'storageAccountKey' must not be specified with 'managedServiceIdentity'")
Expand Down Expand Up @@ -120,7 +120,7 @@ func (self *clientOrObjectId) isEmpty() bool {

// parseAndValidateSettings reads configuration from configFolder, decrypts it,
// runs JSON-schema and logical validation on it and returns it back.
func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int) (h handlerSettings, _ vmextension.ErrorWithClarification) {
func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int) (h handlerSettings, _ error) {
ctx.Log("event", "reading configuration")
pubJSON, protJSON, err := readSettings(configFolder, seqNum)
if err != nil {
Expand All @@ -130,13 +130,13 @@ func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int)

ctx.Log("event", "validating json schema")
if err := validateSettingsSchema(pubJSON, protJSON); err != nil {
return h, vmextension.NewErrorWithClarification(errorutil.Internal_badConfig, errors.Wrap(err, "json validation error"))
return h, vmextension.NewErrorWithClarification(errorutil.Internal_badConfig, errors.Wrap(err, "json validation error"))
}
ctx.Log("event", "json schema valid")

ctx.Log("event", "parsing configuration json")
if err := UnmarshalHandlerSettings(pubJSON, protJSON, &h.publicSettings, &h.protectedSettings); err != nil {
return h, vmextension.NewErrorWithClarification(errorutil.Internal_badConfig, errors.Wrap(err, "json parsing error"))
return h, vmextension.NewErrorWithClarification(errorutil.Internal_badConfig, errors.Wrap(err, "json parsing error"))
}
ctx.Log("event", "parsed configuration json")

Expand All @@ -146,7 +146,7 @@ func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int)
return h, ewc
}
ctx.Log("event", "validated configuration")
return h, vmextension.NewErrorWithClarification(errorutil.NoError, nil)
return h, nil
}

// readSettings uses specified configFolder (comes from HandlerEnvironment) to
Expand Down
Loading