diff --git a/cmd/suseconnect/suseconnect.go b/cmd/suseconnect/suseconnect.go index 0d21157e..b5cc969e 100644 --- a/cmd/suseconnect/suseconnect.go +++ b/cmd/suseconnect/suseconnect.go @@ -1,21 +1,22 @@ package main import ( + "bufio" _ "embed" "encoding/json" "errors" "flag" "fmt" + "github.com/SUSE/connect-ng/internal/connect" + "github.com/SUSE/connect-ng/internal/util" + "github.com/SUSE/connect-ng/internal/zypper" + "io" "net/http" "net/url" "os" "runtime" "strings" "syscall" - - "github.com/SUSE/connect-ng/internal/connect" - "github.com/SUSE/connect-ng/internal/util" - "github.com/SUSE/connect-ng/internal/zypper" ) var ( @@ -150,9 +151,7 @@ func main() { connect.CFG.Namespace = namespace writeConfig = true } - if token != "" { - connect.CFG.Token = token - } + parseRegistrationToken(token) if product.isSet { if p, err := connect.SplitTriplet(product.value); err != nil { fmt.Print("Please provide the product identifier in this format: ") @@ -316,6 +315,18 @@ func main() { } } +func parseRegistrationToken(token string) { + if token != "" { + connect.CFG.Token = token + processedToken, processTokenErr := processToken(token) + if processTokenErr != nil { + fmt.Printf("Error Processing token with error %+v", processTokenErr) + os.Exit(1) + } + connect.CFG.Token = processedToken + } +} + func maybeBrokenSMTError() error { if !connect.CFG.IsScc() && !connect.UpToDate() { return fmt.Errorf("Your Registration Proxy server doesn't support this function. " + @@ -403,3 +414,32 @@ func fileExists(path string) bool { func isSumaManaged() bool { return fileExists("/etc/sysconfig/rhn/systemid") } + +func processToken(token string) (string, error) { + if strings.HasPrefix(token, "@") { + tokenFilePath := strings.TrimPrefix(token, "@") + file, err := os.Open(tokenFilePath) + if err != nil { + return "", fmt.Errorf("failed to open token file '%s': %w", tokenFilePath, err) + } + defer file.Close() + return readTokenFromReader(file) + } else if token == "-" { + return readTokenFromReader(os.Stdin) + } else { + return token, nil + } +} + +func readTokenFromReader(reader io.Reader) (string, error) { + bufReader := bufio.NewReader(reader) + tokenBytes, err := bufReader.ReadString('\n') + if err != nil && err != io.EOF { + return "", fmt.Errorf("failed to read token from reader: %w", err) + } + token := strings.TrimSpace(tokenBytes) + if token == "" { + return "", fmt.Errorf("error: token cannot be empty after reading") + } + return token, nil +} diff --git a/cmd/suseconnect/suseconnect_test.go b/cmd/suseconnect/suseconnect_test.go new file mode 100644 index 00000000..ab7e02f7 --- /dev/null +++ b/cmd/suseconnect/suseconnect_test.go @@ -0,0 +1,280 @@ +package main + +import ( + "errors" + "fmt" + "github.com/SUSE/connect-ng/internal/connect" + "github.com/SUSE/connect-ng/internal/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "os" + "strings" + "testing" +) + +var processTokenFunc = processToken + +var exitCalled bool +var exit = func(code int) { + exitCalled = true +} + +type MockProcessToken struct { + mock.Mock +} + +func (m *MockProcessToken) ProcessToken(token string) (string, error) { + args := m.Called(token) + return args.String(0), args.Error(1) +} + +func init() { + processTokenFunc = processToken +} + +type errorReader struct{} + +func (e *errorReader) Read(p []byte) (n int, err error) { + return 0, fmt.Errorf("forced reader error") +} + +func TestReadTokenFromErrorValidToken(t *testing.T) { + inputToken := "validToken\n" + reader := strings.NewReader(inputToken) + token, err := readTokenFromReader(reader) + if err != nil { + t.Fatalf("Expected no error but got %v", err) + } + if token != "validToken" { + t.Fatalf("Expected token string to be 'validToken' but got '%s'", token) + } +} + +func TestReadTokenFromReader_MultipleNewlines(t *testing.T) { + input := "firstToken\nsecondToken\n" + reader := strings.NewReader(input) + + token, err := readTokenFromReader(reader) + if err != nil { + t.Fatalf("Expected no error, but got %v", err) + } + + expected := "firstToken" + if token != expected { + t.Errorf("Expected token to be '%s', but got '%s'", expected, token) + } +} + +func TestReadTokenFromReader_EmptyInput(t *testing.T) { + reader := strings.NewReader("") + + token, err := readTokenFromReader(reader) + if err == nil { + t.Fatalf("Expected error, but got none") + } + + expectedError := "error: token cannot be empty after reading" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', but got '%s'", expectedError, err.Error()) + } + + if token != "" { + t.Errorf("Expected empty token, but got '%s'", token) + } +} + +func TestReadTokenFromReader_OnlyNewline(t *testing.T) { + reader := strings.NewReader("\n") + + token, err := readTokenFromReader(reader) + if err == nil { + t.Fatalf("Expected error, but got none") + } + + expectedError := "error: token cannot be empty after reading" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', but got '%s'", expectedError, err.Error()) + } + + if token != "" { + t.Errorf("Expected empty token, but got '%s'", token) + } +} + +func TestReadTokenFromReader_ErrorProducingReader(t *testing.T) { + reader := &errorReader{} + + token, err := readTokenFromReader(reader) + if err == nil { + t.Fatalf("Expected error, but got none") + } + + expectedError := "failed to read token from reader: forced reader error" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', but got '%s'", expectedError, err.Error()) + } + + if token != "" { + t.Errorf("Expected empty token, but got '%s'", token) + } +} + +func TestProcessToken_RegularToken(t *testing.T) { + token := "myRegularToken" + result, err := processToken(token) + if err != nil { + t.Fatalf("Expected no error, but got %v", err) + } + + if result != token { + t.Errorf("Expected token to be '%s', but got '%s'", token, result) + } +} + +func TestProcessToken_TokenFromFile(t *testing.T) { + tmpFile, err := os.CreateTemp("", "tokenfile") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + + expectedToken := "fileToken\n" + if _, err := tmpFile.WriteString(expectedToken); err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + tmpFile.Close() + + token := "@" + tmpFile.Name() + result, err := processToken(token) + if err != nil { + t.Fatalf("Expected no error, but got %v", err) + } + + expectedToken = strings.TrimSpace(expectedToken) + if result != expectedToken { + t.Errorf("Expected token to be '%s', but got '%s'", expectedToken, result) + } +} + +func TestProcessToken_NonExistentFile(t *testing.T) { + token := "@/non/existent/file" + _, err := processToken(token) + if err == nil { + t.Fatalf("Expected error for non-existent file, but got none") + } + + expectedError := "failed to open token file '/non/existent/file'" + if !strings.Contains(err.Error(), expectedError) { + t.Errorf("Expected error containing '%s', but got '%v'", expectedError, err) + } +} + +func TestProcessToken_TokenFromStdin(t *testing.T) { + tempFile, err := os.CreateTemp("", "test_stdin") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tempFile.Name()) + expectedToken := "stdinToken\n" + if _, err := tempFile.WriteString(expectedToken); err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + + tempFile.Close() + + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + file, err := os.Open(tempFile.Name()) + if err != nil { + t.Fatalf("Failed to open temp file: %v", err) + } + os.Stdin = file + + result, err := processToken("-") + if err != nil { + t.Fatalf("Expected no error, but got %v", err) + } + + expectedToken = strings.TrimSpace(expectedToken) + if result != expectedToken { + t.Errorf("Expected token to be '%s', but got '%s'", expectedToken, result) + } +} + +func TestProcessToken_ErrorReadingStdin(t *testing.T) { + tempFile, err := os.CreateTemp("", "test_stdin_empty") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tempFile.Name()) + + tempFile.Close() + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + file, err := os.Open(tempFile.Name()) + if err != nil { + t.Fatalf("Failed to open temp file: %v", err) + } + os.Stdin = file + + _, err = processToken("-") + if err == nil { + t.Fatalf("Expected error reading from stdin, but got none") + } + + expectedError := "error: token cannot be empty after reading" + if !strings.Contains(err.Error(), expectedError) { + t.Errorf("Expected error containing '%s', but got '%v'", expectedError, err) + } +} + +func parseRegistrationTokenWithInjection(token string) { + if token != "" { + connect.CFG.Token = token + processedToken, processTokenErr := processTokenFunc(token) + if processTokenErr != nil { + util.Debug.Printf("Error Processing token %+v", processTokenErr) + exit(1) + } + connect.CFG.Token = processedToken + } +} + +func TestParseToken_Success(t *testing.T) { + mockProcessToken := new(MockProcessToken) + mockProcessToken.On("ProcessToken", "valid-token").Return("processed-token", nil) + + processTokenFunc = mockProcessToken.ProcessToken + + exitCalled = false + + parseRegistrationTokenWithInjection("valid-token") + + assert.Equal(t, "processed-token", connect.CFG.Token, "Token should be processed correctly") + assert.False(t, exitCalled, "os.Exit (simulated) should not be called in a successful case") + + mockProcessToken.AssertExpectations(t) +} + +func TestParseToken_ProcessTokenError(t *testing.T) { + mockProcessToken := new(MockProcessToken) + mockProcessToken.On("ProcessToken", "invalid-token").Return("", errors.New("failed to process token")) + + processTokenFunc = mockProcessToken.ProcessToken + + exitCalled = false + + parseRegistrationTokenWithInjection("invalid-token") + + assert.True(t, exitCalled, "os.Exit (simulated) should be called when processToken fails") + assert.Equal(t, "", connect.CFG.Token, "Token should not be updated when processToken fails") + + mockProcessToken.AssertExpectations(t) +} + +func TestParseToken_EmptyToken(t *testing.T) { + exitCalled = false + parseRegistrationTokenWithInjection("") + assert.Empty(t, connect.CFG.Token, "Token should not be updated when input token is empty") + assert.False(t, exitCalled, "os.Exit (simulated) should not be called for empty token") +}