diff --git a/src/datachannel/streaming.go b/src/datachannel/streaming.go index 3ff4c59f..a7cd27ee 100644 --- a/src/datachannel/streaming.go +++ b/src/datachannel/streaming.go @@ -21,8 +21,8 @@ import ( "encoding/json" "errors" "fmt" + "io" "math" - "os" "reflect" "sync" "time" @@ -115,6 +115,9 @@ type DataChannel struct { // AgentVersion received during handshake agentVersion string + + // Out is where user ssm plugin logs go + Out io.Writer } type ListMessageBuffer struct { @@ -510,7 +513,7 @@ func (dataChannel *DataChannel) handleHandshakeComplete(log log.T, clientMessage handshakeComplete.HandshakeTimeToComplete.Seconds()) if handshakeComplete.CustomerMessage != "" { - fmt.Fprintln(os.Stdout, handshakeComplete.CustomerMessage) + fmt.Fprintln(dataChannel.Out, handshakeComplete.CustomerMessage) } return err @@ -783,9 +786,9 @@ func (dataChannel DataChannel) HandleChannelClosedMessage(log log.T, stopHandler log.Infof("Exiting session with sessionId: %s with output: %s", sessionId, channelClosedMessage.Output) if channelClosedMessage.Output == "" { - fmt.Fprintf(os.Stdout, "\n\nExiting session with sessionId: %s.\n\n", sessionId) + fmt.Fprintf(dataChannel.Out, "\n\nExiting session with sessionId: %s.\n\n", sessionId) } else { - fmt.Fprintf(os.Stdout, "\n\nSessionId: %s : %s\n\n", sessionId, channelClosedMessage.Output) + fmt.Fprintf(dataChannel.Out, "\n\nSessionId: %s : %s\n\n", sessionId, channelClosedMessage.Output) } stopHandler() diff --git a/src/sessionmanagerplugin-main/main.go b/src/sessionmanagerplugin-main/main.go index 49a11c08..0a02bec5 100644 --- a/src/sessionmanagerplugin-main/main.go +++ b/src/sessionmanagerplugin-main/main.go @@ -22,6 +22,14 @@ import ( _ "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session/shellsession" ) +var out = os.Stdout + +func init() { + if quietStr := os.Getenv("AWS_SSM_QUIET"); quietStr == "true" || quietStr == "1" { + out = os.Stderr + } +} + func main() { - session.ValidateInputAndStartSession(os.Args, os.Stdout) + session.ValidateInputAndStartSession(os.Args, out) } diff --git a/src/sessionmanagerplugin/session/portsession/basicportforwarding.go b/src/sessionmanagerplugin/session/portsession/basicportforwarding.go index 65f1057c..63f92a7d 100644 --- a/src/sessionmanagerplugin/session/portsession/basicportforwarding.go +++ b/src/sessionmanagerplugin/session/portsession/basicportforwarding.go @@ -16,6 +16,7 @@ package portsession import ( "fmt" + "io" "net" "os" "os/signal" @@ -39,6 +40,7 @@ type BasicPortForwarding struct { sessionId string portParameters PortParameters session session.Session + out io.Writer } // getNewListener returns a new listener to given address and type like tcp, unix etc. @@ -132,7 +134,7 @@ func (p *BasicPortForwarding) startLocalConn(log log.T) (err error) { return err } log.Infof("Connection accepted for session %s.", p.sessionId) - fmt.Printf("Connection accepted for session %s.\n", p.sessionId) + fmt.Fprintf(p.out, "Connection accepted for session %s.\n", p.sessionId) p.listener = &listener p.stream = &tcpConn @@ -159,7 +161,7 @@ func (p *BasicPortForwarding) startLocalListener(log log.T, portNumber string) ( } log.Info(displayMessage) - fmt.Println(displayMessage) + fmt.Fprintln(p.out, displayMessage) return } @@ -169,13 +171,13 @@ func (p *BasicPortForwarding) handleControlSignals(log log.T) { signal.Notify(c, sessionutil.ControlSignals...) go func() { <-c - fmt.Println("Terminate signal received, exiting.") + fmt.Fprintln(p.out, "Terminate signal received, exiting.") if version.DoesAgentSupportTerminateSessionFlag(log, p.session.DataChannel.GetAgentVersion()) { if err := p.session.DataChannel.SendFlag(log, message.TerminateSession); err != nil { log.Errorf("Failed to send TerminateSession flag: %v", err) } - fmt.Fprintf(os.Stdout, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId) + fmt.Fprintf(p.out, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId) p.Stop() } else { p.session.TerminateSession(log) diff --git a/src/sessionmanagerplugin/session/portsession/basicportforwarding_test.go b/src/sessionmanagerplugin/session/portsession/basicportforwarding_test.go index ae1bf8de..a5593392 100644 --- a/src/sessionmanagerplugin/session/portsession/basicportforwarding_test.go +++ b/src/sessionmanagerplugin/session/portsession/basicportforwarding_test.go @@ -48,6 +48,7 @@ func TestSetSessionHandlers(t *testing.T) { Session: mockSession, portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"}, portSessionType: &BasicPortForwarding{ + out: os.Stdout, session: mockSession, portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"}, }, @@ -84,6 +85,7 @@ func TestStartSessionTCPLocalPortFromDocument(t *testing.T) { Session: getSessionMock(), portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding", LocalPortNumber: "54321"}, portSessionType: &BasicPortForwarding{ + out: os.Stdout, session: getSessionMock(), portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"}, }, @@ -101,6 +103,7 @@ func TestStartSessionTCPAcceptFailed(t *testing.T) { Session: getSessionMock(), portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"}, portSessionType: &BasicPortForwarding{ + out: os.Stdout, session: getSessionMock(), portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"}, }, @@ -117,6 +120,7 @@ func TestStartSessionTCPConnectFailed(t *testing.T) { Session: getSessionMock(), portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"}, portSessionType: &BasicPortForwarding{ + out: os.Stdout, session: getSessionMock(), portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"}, }, diff --git a/src/sessionmanagerplugin/session/portsession/muxportforwarding.go b/src/sessionmanagerplugin/session/portsession/muxportforwarding.go index 85fe0ce6..52488b41 100644 --- a/src/sessionmanagerplugin/session/portsession/muxportforwarding.go +++ b/src/sessionmanagerplugin/session/portsession/muxportforwarding.go @@ -61,6 +61,7 @@ type MuxPortForwarding struct { session session.Session muxClient *MuxClient mgsConn *MgsConn + out io.Writer } func (c *MgsConn) close() { @@ -131,7 +132,7 @@ func (p *MuxPortForwarding) WriteStream(outputMessage message.ClientMessage) err binary.Read(buf, binary.BigEndian, &flag) if message.ConnectToPortError == flag { - fmt.Printf("\nConnection to destination port failed, check SSM Agent logs.\n") + fmt.Fprintf(p.out, "\nConnection to destination port failed, check SSM Agent logs.\n") } } return nil @@ -190,12 +191,12 @@ func (p *MuxPortForwarding) handleControlSignals(log log.T) { signal.Notify(c, sessionutil.ControlSignals...) go func() { <-c - fmt.Println("Terminate signal received, exiting.") + fmt.Fprintln(p.out, "Terminate signal received, exiting.") if err := p.session.DataChannel.SendFlag(log, message.TerminateSession); err != nil { log.Errorf("Failed to send TerminateSession flag: %v", err) } - fmt.Fprintf(os.Stdout, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId) + fmt.Fprintf(p.out, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId) p.Stop() }() } @@ -252,10 +253,10 @@ func (p *MuxPortForwarding) handleClientConnections(log log.T, ctx context.Conte defer listener.Close() log.Infof(displayMsg) - fmt.Printf(displayMsg) + fmt.Fprintf(p.out, displayMsg) log.Infof("Waiting for connections...\n") - fmt.Printf("\nWaiting for connections...\n") + fmt.Fprintf(p.out, "\nWaiting for connections...\n") var once sync.Once for { @@ -269,7 +270,7 @@ func (p *MuxPortForwarding) handleClientConnections(log log.T, ctx context.Conte log.Infof("Connection accepted from %s\n for session [%s]", conn.RemoteAddr(), p.sessionId) once.Do(func() { - fmt.Printf("\nConnection accepted for session [%s]\n", p.sessionId) + fmt.Fprintf(p.out, "\nConnection accepted for session [%s]\n", p.sessionId) }) stream, err := p.muxClient.session.OpenStream() diff --git a/src/sessionmanagerplugin/session/portsession/portsession.go b/src/sessionmanagerplugin/session/portsession/portsession.go index 793b6a74..454b1129 100644 --- a/src/sessionmanagerplugin/session/portsession/portsession.go +++ b/src/sessionmanagerplugin/session/portsession/portsession.go @@ -15,6 +15,8 @@ package portsession import ( + "os" + "github.com/aws/session-manager-plugin/src/config" "github.com/aws/session-manager-plugin/src/jsonutil" "github.com/aws/session-manager-plugin/src/log" @@ -70,12 +72,14 @@ func (s *PortSession) Initialize(log log.T, sessionVar *session.Session) { sessionId: s.SessionId, portParameters: s.portParameters, session: s.Session, + out: os.Stdout, } } else { s.portSessionType = &BasicPortForwarding{ sessionId: s.SessionId, portParameters: s.portParameters, session: s.Session, + out: os.Stdout, } } } else { diff --git a/src/sessionmanagerplugin/session/portsession/test_portsession.go b/src/sessionmanagerplugin/session/portsession/test_portsession.go index 4d9cf826..ce6e0e94 100644 --- a/src/sessionmanagerplugin/session/portsession/test_portsession.go +++ b/src/sessionmanagerplugin/session/portsession/test_portsession.go @@ -15,6 +15,8 @@ package portsession import ( + "os" + "github.com/aws/session-manager-plugin/src/communicator/mocks" "github.com/aws/session-manager-plugin/src/datachannel" "github.com/aws/session-manager-plugin/src/log" @@ -41,11 +43,13 @@ func getSessionMock() session.Session { } func getSessionMockWithParams(properties interface{}, agentVersion string) session.Session { - datachannel := &datachannel.DataChannel{} + out := os.Stdout + datachannel := &datachannel.DataChannel{Out: out} datachannel.SetAgentVersion(agentVersion) var mockSession = session.Session{ DataChannel: datachannel, + Out: out, } mockSession.DataChannel.Initialize(mockLog, "clientId", "sessionId", "targetId", false) diff --git a/src/sessionmanagerplugin/session/session.go b/src/sessionmanagerplugin/session/session.go index 567a35dc..356694bd 100644 --- a/src/sessionmanagerplugin/session/session.go +++ b/src/sessionmanagerplugin/session/session.go @@ -85,14 +85,15 @@ type Session struct { SessionType string SessionProperties interface{} DisplayMode sessionutil.DisplayMode + Out io.Writer } -//startSession create the datachannel for session +// startSession create the datachannel for session var startSession = func(session *Session, log log.T) error { return session.Execute(log) } -//setSessionHandlersWithSessionType set session handlers based on session subtype +// setSessionHandlersWithSessionType set session handlers based on session subtype var setSessionHandlersWithSessionType = func(session *Session, log log.T) error { // SessionType is set inside DataChannel sessionSubType := SessionRegistry[session.SessionType] @@ -203,7 +204,8 @@ func ValidateInputAndStartSession(args []string, out io.Writer) { session.Endpoint = ssmEndpoint session.ClientId = clientId session.TargetId = target - session.DataChannel = &datachannel.DataChannel{} + session.DataChannel = &datachannel.DataChannel{Out: out} + session.Out = out default: fmt.Fprint(out, "Invalid Operation") @@ -217,12 +219,12 @@ func ValidateInputAndStartSession(args []string, out io.Writer) { } } -//Execute create data channel and start the session +// Execute create data channel and start the session func (s *Session) Execute(log log.T) (err error) { - fmt.Fprintf(os.Stdout, "\nStarting session with SessionId: %s\n", s.SessionId) + fmt.Fprintf(s.Out, "\nStarting session with SessionId: %s\n", s.SessionId) // sets the display mode - s.DisplayMode = sessionutil.NewDisplayMode(log) + s.DisplayMode = sessionutil.NewDisplayMode(log, s.Out) if err = s.OpenDataChannel(log); err != nil { log.Errorf("Error in Opening data channel: %v", err) diff --git a/src/sessionmanagerplugin/session/session_test.go b/src/sessionmanagerplugin/session/session_test.go index 50f9f252..097d5713 100644 --- a/src/sessionmanagerplugin/session/session_test.go +++ b/src/sessionmanagerplugin/session/session_test.go @@ -104,7 +104,7 @@ func TestValidateInputAndStartSessionWithWrongEnvVariableName(t *testing.T) { } func TestExecute(t *testing.T) { - sessionMock := &Session{} + sessionMock := &Session{Out: os.Stdout} sessionMock.DataChannel = mockDataChannel SetupMockActions() mockDataChannel.On("Open", mock.Anything).Return(nil) @@ -128,7 +128,7 @@ func TestExecute(t *testing.T) { } func TestExecuteAndStreamMessageResendTimesOut(t *testing.T) { - sessionMock := &Session{} + sessionMock := &Session{Out: os.Stdout} sessionMock.DataChannel = mockDataChannel SetupMockActions() mockDataChannel.On("Open", mock.Anything).Return(nil) diff --git a/src/sessionmanagerplugin/session/sessionhandler.go b/src/sessionmanagerplugin/session/sessionhandler.go index cab3cfa7..210b1d37 100644 --- a/src/sessionmanagerplugin/session/sessionhandler.go +++ b/src/sessionmanagerplugin/session/sessionhandler.go @@ -129,7 +129,7 @@ func (s *Session) ResumeSessionHandler(log log.T) (err error) { return } else if s.TokenValue == "" { log.Debugf("Session: %s timed out", s.SessionId) - fmt.Fprintf(os.Stdout, "Session: %s timed out.\n", s.SessionId) + fmt.Fprintf(s.Out, "Session: %s timed out.\n", s.SessionId) os.Exit(0) } s.DataChannel.GetWsChannel().SetChannelToken(s.TokenValue) diff --git a/src/sessionmanagerplugin/session/sessionhandler_test.go b/src/sessionmanagerplugin/session/sessionhandler_test.go index ba14b402..06054ef8 100644 --- a/src/sessionmanagerplugin/session/sessionhandler_test.go +++ b/src/sessionmanagerplugin/session/sessionhandler_test.go @@ -16,6 +16,7 @@ package session import ( "fmt" + "os" "testing" wsChannelMock "github.com/aws/session-manager-plugin/src/communicator/mocks" @@ -23,6 +24,7 @@ import ( "github.com/aws/session-manager-plugin/src/datachannel" dataChannelMock "github.com/aws/session-manager-plugin/src/datachannel/mocks" "github.com/aws/session-manager-plugin/src/message" + "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session/sessionutil" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/assert" @@ -38,7 +40,7 @@ func TestOpenDataChannel(t *testing.T) { mockDataChannel = &dataChannelMock.IDataChannel{} mockWsChannel = &wsChannelMock.IWebSocketChannel{} - sessionMock := &Session{} + sessionMock := &Session{Out: os.Stdout} sessionMock.DataChannel = mockDataChannel SetupMockActions() mockDataChannel.On("Open", mock.Anything).Return(nil) @@ -51,7 +53,7 @@ func TestOpenDataChannelWithError(t *testing.T) { mockDataChannel = &dataChannelMock.IDataChannel{} mockWsChannel = &wsChannelMock.IWebSocketChannel{} - sessionMock := &Session{} + sessionMock := &Session{Out: os.Stdout} sessionMock.DataChannel = mockDataChannel SetupMockActions() @@ -69,10 +71,12 @@ func TestProcessFirstMessageOutputMessageFirst(t *testing.T) { Payload: []byte("testing"), } - dataChannel := &datachannel.DataChannel{} + dataChannel := &datachannel.DataChannel{Out: os.Stdout} dataChannel.Initialize(logger, clientId, sessionId, instanceId, false) session := Session{ + Out: os.Stdout, DataChannel: dataChannel, + DisplayMode: sessionutil.NewDisplayMode(logger, os.Stdout), } session.ProcessFirstMessage(logger, outputMessage) diff --git a/src/sessionmanagerplugin/session/sessionutil/sessionutil.go b/src/sessionmanagerplugin/session/sessionutil/sessionutil.go index b6bd4db2..b69abe2c 100644 --- a/src/sessionmanagerplugin/session/sessionutil/sessionutil.go +++ b/src/sessionmanagerplugin/session/sessionutil/sessionutil.go @@ -14,10 +14,14 @@ // Package sessionutil provides utility for sessions. package sessionutil -import "github.com/aws/session-manager-plugin/src/log" +import ( + "io" -func NewDisplayMode(log log.T) DisplayMode { - displayMode := DisplayMode{} + "github.com/aws/session-manager-plugin/src/log" +) + +func NewDisplayMode(log log.T, out io.Writer) DisplayMode { + displayMode := DisplayMode{out: out} displayMode.InitDisplayMode(log) return displayMode } diff --git a/src/sessionmanagerplugin/session/sessionutil/sessionutil_unix.go b/src/sessionmanagerplugin/session/sessionutil/sessionutil_unix.go index 023915fa..8b78d261 100644 --- a/src/sessionmanagerplugin/session/sessionutil/sessionutil_unix.go +++ b/src/sessionmanagerplugin/session/sessionutil/sessionutil_unix.go @@ -21,13 +21,13 @@ import ( "fmt" "io" "net" - "os" "github.com/aws/session-manager-plugin/src/log" "github.com/aws/session-manager-plugin/src/message" ) type DisplayMode struct { + out io.Writer } func (d *DisplayMode) InitDisplayMode(log log.T) { @@ -35,8 +35,7 @@ func (d *DisplayMode) InitDisplayMode(log log.T) { // DisplayMessage function displays the output on the screen func (d *DisplayMode) DisplayMessage(log log.T, message message.ClientMessage) { - var out io.Writer = os.Stdout - fmt.Fprint(out, string(message.Payload)) + fmt.Fprint(d.out, string(message.Payload)) } // NewListener starts a new socket listener on the address. diff --git a/src/sessionmanagerplugin/session/sessionutil/sessionutil_windows.go b/src/sessionmanagerplugin/session/sessionutil/sessionutil_windows.go index 95dc2775..a70a7011 100644 --- a/src/sessionmanagerplugin/session/sessionutil/sessionutil_windows.go +++ b/src/sessionmanagerplugin/session/sessionutil/sessionutil_windows.go @@ -19,6 +19,7 @@ package sessionutil import ( "fmt" + "io" "net" "os" "syscall" @@ -32,6 +33,7 @@ var EnvProgramFiles = os.Getenv("ProgramFiles") type DisplayMode struct { handle windows.Handle + out io.Writer } func (d *DisplayMode) InitDisplayMode(log log.T) { @@ -71,7 +73,7 @@ func (d *DisplayMode) DisplayMessage(log log.T, message message.ClientMessage) { // refer - https://docs.microsoft.com/en-us/windows/desktop/api/fileapi/nf-fileapi-writefile if err = windows.WriteFile(d.handle, message.Payload, done, nil); err != nil { log.Errorf("error occurred while writing to file: %v", err) - fmt.Fprintf(os.Stdout, "\nError getting the output. %s\n", err.Error()) + fmt.Fprintf(d.out, "\nError getting the output. %s\n", err.Error()) os.Exit(0) } } diff --git a/src/sessionmanagerplugin/session/shellsession/shellsession_test.go b/src/sessionmanagerplugin/session/shellsession/shellsession_test.go index 51304f09..89bf067f 100644 --- a/src/sessionmanagerplugin/session/shellsession/shellsession_test.go +++ b/src/sessionmanagerplugin/session/shellsession/shellsession_test.go @@ -52,7 +52,7 @@ func TestName(t *testing.T) { } func TestInitialize(t *testing.T) { - session := &session.Session{} + session := &session.Session{Out: os.Stdout} shellSession := ShellSession{} session.DataChannel = mockDataChannel mockDataChannel.On("RegisterOutputStreamHandler", mock.Anything, true).Times(1) @@ -63,7 +63,7 @@ func TestInitialize(t *testing.T) { } func TestHandleControlSignals(t *testing.T) { - session := session.Session{} + session := session.Session{Out: os.Stdout} session.DataChannel = mockDataChannel shellSession := ShellSession{} shellSession.Session = session @@ -154,7 +154,7 @@ func TestTerminalResizeWhenSessionSizeDataIsNotEqualToActualSize(t *testing.T) { func TestProcessStreamMessagePayload(t *testing.T) { shellSession := ShellSession{} - shellSession.DisplayMode = sessionutil.NewDisplayMode(logger) + shellSession.DisplayMode = sessionutil.NewDisplayMode(logger, os.Stdout) msg := message.ClientMessage{ Payload: []byte("Hello Agent\n"), @@ -165,7 +165,7 @@ func TestProcessStreamMessagePayload(t *testing.T) { } func getDataChannel() *datachannel.DataChannel { - dataChannel := &datachannel.DataChannel{} + dataChannel := &datachannel.DataChannel{Out: os.Stdout} dataChannel.Initialize(logger, clientId, sessionId, instanceId, false) dataChannel.SetWsChannel(mockWsChannel) return dataChannel diff --git a/src/ssmclicommands/inputhandler.go b/src/ssmclicommands/inputhandler.go index 70c4b25f..34861d78 100644 --- a/src/ssmclicommands/inputhandler.go +++ b/src/ssmclicommands/inputhandler.go @@ -106,7 +106,7 @@ func ValidateInput(args []string, out io.Writer) { if utils.IsHelp(subcommand, parameters) { fmt.Fprintln(out, cmd.Help()) } else { - cmdErr, result := cmd.Execute(parameters) + cmdErr, result := cmd.Execute(out, parameters) if cmdErr != nil { utils.DisplayCommandUsage(out) fmt.Fprint(out, cmdErr.Error()) diff --git a/src/ssmclicommands/startsession.go b/src/ssmclicommands/startsession.go index d37edb7d..73da6086 100644 --- a/src/ssmclicommands/startsession.go +++ b/src/ssmclicommands/startsession.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "html/template" + "io" "strings" sdkSession "github.com/aws/aws-sdk-go/aws/session" @@ -88,7 +89,7 @@ type StartSessionCommand struct { sdk *ssm.SSM } -//getSSMClient generate ssm client by configuration +// getSSMClient generate ssm client by configuration var getSSMClient = func(log log.T, region string, profile string, endpoint string) (*ssm.SSM, error) { sdkutil.SetRegionAndProfile(region, profile) @@ -101,7 +102,7 @@ var getSSMClient = func(log log.T, region string, profile string, endpoint strin return ssm.New(sdkSession), nil } -//executeSession to open datachannel +// executeSession to open datachannel var executeSession = func(log log.T, session *session.Session) (err error) { return session.Execute(log) } @@ -141,8 +142,8 @@ func (c *StartSessionCommand) Help() string { return c.helpText } -//validates and execute start-session command -func (s *StartSessionCommand) Execute(parameters map[string][]string) (error, string) { +// validates and execute start-session command +func (s *StartSessionCommand) Execute(out io.Writer, parameters map[string][]string) (error, string) { var ( err error region string @@ -191,7 +192,8 @@ func (s *StartSessionCommand) Execute(parameters map[string][]string) (error, st Endpoint: endpoint, ClientId: clientId, TargetId: instanceId, - DataChannel: &datachannel.DataChannel{}, + DataChannel: &datachannel.DataChannel{Out: out}, + Out: out, } if err = executeSession(log, &session); err != nil { @@ -202,7 +204,7 @@ func (s *StartSessionCommand) Execute(parameters map[string][]string) (error, st return err, "StartSession executed successfully" } -//func to validate start-session input +// func to validate start-session input func (StartSessionCommand) validateStartSessionInput(parameters map[string][]string) []string { validation := make([]string, 0) diff --git a/src/ssmclicommands/startsession_test.go b/src/ssmclicommands/startsession_test.go index a62718e4..e02f8830 100644 --- a/src/ssmclicommands/startsession_test.go +++ b/src/ssmclicommands/startsession_test.go @@ -16,6 +16,7 @@ package ssmclicommands import ( "fmt" + "os" "testing" "github.com/aws/aws-sdk-go/service/ssm" @@ -75,7 +76,7 @@ func TestStartSessionCommand_ExecuteSuccess(t *testing.T) { return startSessionOutput, nil } - err, msg := command.Execute(parameter) + err, msg := command.Execute(os.Stdout, parameter) assert.Nil(t, err) assert.Equal(t, msg, "StartSession executed successfully") } @@ -91,7 +92,7 @@ func TestStartSessionCommand_ExecuteGetSSMClientFailure(t *testing.T) { return nil, fmt.Errorf("Get SSMClient Failure") } - err, msg := command.Execute(parameter) + err, msg := command.Execute(os.Stdout, parameter) assert.NotNil(t, err) assert.Equal(t, err.Error(), "Get SSMClient Failure") assert.Equal(t, msg, "StartSession failed") @@ -115,7 +116,7 @@ func TestStartSessionCommand_ExecuteSessionFailure(t *testing.T) { return startSessionOutput, nil } - err, msg := command.Execute(parameter) + err, msg := command.Execute(os.Stdout, parameter) assert.NotNil(t, err) assert.Equal(t, err.Error(), "Execute Session Failure") assert.Equal(t, msg, "StartSession failed") diff --git a/src/ssmclicommands/utils/util.go b/src/ssmclicommands/utils/util.go index 339a940b..cc5a13b9 100644 --- a/src/ssmclicommands/utils/util.go +++ b/src/ssmclicommands/utils/util.go @@ -16,6 +16,7 @@ package utils import ( "fmt" + "io" "strings" ) @@ -30,7 +31,7 @@ var SsmCliCommands map[string]SsmCliCommand // CliCommand defines the interface for all commands the cli can execute type SsmCliCommand interface { - Execute(parameters map[string][]string) (error, string) + Execute(out io.Writer, parameters map[string][]string) (error, string) Help() string Name() string }